In [1]:
!pip install git+https://github.com/openai/CLIP.git

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-wl5il64j
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-wl5il64j
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [2]:
import os
import torch
import torch.nn as nn 
import numpy as np
import h5py
import torchvision.transforms as transforms
from torchvision.models import resnet18, ResNet18_Weights
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import TQDMProgressBar, ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
import argparse
import yaml
import pandas as pd
from scipy import ndimage
from timm.models.vision_transformer import vit_base_patch16_224
from transformers import LongformerModel, LongformerConfig
import clip

2025-05-24 09:46:26.616924: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1748079986.640598     157 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1748079986.647393     157 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
p = "/kaggle/input/scenarios/scenarios_small_100.txt"
with open(p) as file:
    lines = [line.strip() for line in file]
scenarios = lines
scenarios_tokens = clip.tokenize(scenarios)

def pad_collate(batch):
    meta, img, vego, angle, dist = zip(*batch)
    from torch.nn.utils.rnn import pad_sequence
    m_pad = pad_sequence(meta, batch_first=True, padding_value=0)
    i_pad = pad_sequence(img, batch_first=True, padding_value=0)
    vego_pad = pad_sequence(vego, batch_first=True, padding_value=0)
    a_pad = pad_sequence(angle, batch_first=True, padding_value=0)
    d_pad = pad_sequence(dist, batch_first=True, padding_value=0)
    return m_pad, i_pad, vego_pad, a_pad, d_pad, None, None, None, None, None


In [5]:

class CommaDataset(Dataset):
    def __init__(
        self,
        dataset_type="train",
        use_transform=False,
        multitask="angle",
        ground_truth="desired",
        return_full=False, 
        dataset_path ="/kaggle/input/filtered-chunk1" ,
        dataset_fraction=1.0
    ):
        assert dataset_type in ["train", "val", "test"]
        if dataset_type == "val":
            dataset_type = "test" 
        if dataset_type == "test":
            dataset_type = "val" 
        
        self.dataset_type = dataset_type
        self.dataset_fraction = dataset_fraction
        self.max_len = 240
        self.ground_truth = ground_truth
        self.multitask = multitask
        self.use_transform = use_transform
        self.return_full = return_full
        self.normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        self.resize = transforms.Resize((224,224))
        #/data1/shared/jessica/data1/data/
        data_path = f"{dataset_path}/filtered_chunk1_{dataset_type}.hdf5" if ground_truth == "regular" else f"{dataset_path}/filtered_chunk1_{dataset_type}.hdf5"
        self.people_seqs = []
        self.h5_file = h5py.File(data_path, "r")
        corrupt_idx = 62
        self.keys = list(self.h5_file.keys())
        if dataset_type == "train":
            self.keys.pop(corrupt_idx)
           
    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        person_seq = {}
        seq_key  = self.keys[idx]
        keys_ = self.h5_file[seq_key].keys()#'angle', 'brake', 'dist', 'gas', 'image', 'time', 'vEgo'
        file = self.h5_file
        
        for key in keys_:                        
            seq = file[seq_key][key][()]
            seq = seq if len(seq) <= 241 else seq[1::5]
            person_seq[key] = torch.from_numpy(np.array(seq[0:self.max_len]).astype(float)).type(torch.float32)
        sequences = person_seq
        distances = sequences['dist']
        distances = ndimage.median_filter(distances, size=128, mode='nearest')

        steady_state = ~np.array(sequences['gaspressed']).astype(bool) & ~np.array(sequences['brakepressed']).astype(bool) & ~np.array(sequences['leftBlinker']).astype(bool) & ~np.array(sequences['rightBlinker']).astype(bool)
        last_idx = 0
        desired_gap = np.zeros(distances.shape)

        for i in range(len(steady_state)-1):
            if steady_state[i] == True:
                desired_gap[last_idx:i] = int(distances[i])
                last_idx = i
        desired_gap[-12:] = distances[-12:].mean().item()

        distances = sequences['dist'] if self.ground_truth else desired_gap
        images = sequences['image']
        images = images[:,0:160, :,:]#crop the image to remove the view of the inside car console
        images = images.permute(0,3,1,2)
        if not self.return_full:
            images = self.normalize(images/255.0)
            
        else:
            images = images/255.0
        images = self.resize(images)
        images_cropped = images
        intervention = np.array(sequences['gaspressed']).astype(bool) | np.array(sequences['brakepressed']).astype(bool) 
        res = images_cropped, images_cropped,  sequences['vEgo'],  sequences['angle'], distances
        if self.return_full: 
            return images_cropped,  sequences['vEgo'],  sequences['angle'], distances, np.array(sequences['gaspressed']).astype(bool),  np.array(sequences['brakepressed']).astype(bool) , np.array(sequences['CruiseStateenabled']).astype(bool)
        if self.multitask == "distance":
            res = images_cropped, images_cropped, sequences['vEgo'], distances, sequences['angle']
        if self.multitask == "intervention":
            res = images_cropped, images_cropped, sequences['vEgo'], distances, torch.tensor(np.array(sequences['gaspressed']).astype(bool) | np.array(sequences['brakepressed']).astype(bool))
        return res 

In [7]:
def pad_to_window_size_local(input_ids: torch.Tensor, attention_mask: torch.Tensor, position_ids: torch.Tensor,
                             one_sided_window_size: int, pad_token_id: int):
    
    '''A helper function to pad tokens and mask to work with the sliding_chunks implementation of Longformer self-attention.
    Based on _pad_to_window_size from https://github.com/huggingface/transformers:
    https://github.com/huggingface/transformers/blob/71bdc076dd4ba2f3264283d4bc8617755206dccd/src/transformers/models/longformer/modeling_longformer.py#L1516
    Input:
        input_ids = torch.Tensor(bsz x seqlen): ids of wordpieces
        attention_mask = torch.Tensor(bsz x seqlen): attention mask
        one_sided_window_size = int: window size on one side of each token
        pad_token_id = int: tokenizer.pad_token_id
    Returns
        (input_ids, attention_mask) padded to length divisible by 2 * one_sided_window_size
    '''
    w = 2 * one_sided_window_size
    seqlen = input_ids.size(1)
    padding_len = (w - seqlen % w) % w
    input_ids = F.pad(input_ids.permute(0, 2, 1), (0, padding_len), value=pad_token_id).permute(0, 2, 1)
    attention_mask = F.pad(attention_mask, (0, padding_len), value=False)  # no attention on the padding tokens
    position_ids = F.pad(position_ids, (1, padding_len), value=False)  # no attention on the padding tokens
    return input_ids, attention_mask, position_ids


In [8]:

''' This is a modified version of the Longformer model from Huggingface. '''
class VTNLongformerModel(LongformerModel):
    def __init__(self,
                 embed_dim=2048,
                 max_position_embeddings=2 * 60 * 60,
                 num_attention_heads=16,
                 num_hidden_layers=3,
                 attention_mode='sliding_chunks',
                 pad_token_id=-1,
                 attention_window=None,
                 intermediate_size=3072,
                 attention_probs_dropout_prob=0.4,
                 hidden_dropout_prob=0.5):

        self.config = LongformerConfig()
        self.config.attention_mode = attention_mode
        self.config.intermediate_size = intermediate_size
        self.config.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.config.hidden_dropout_prob = hidden_dropout_prob
        self.config.attention_dilation = [1, ] * num_hidden_layers
        self.config.attention_window = [256, ] * num_hidden_layers if attention_window is None else attention_window
        self.config.num_hidden_layers = num_hidden_layers
        self.config.num_attention_heads = num_attention_heads
        self.config.pad_token_id = pad_token_id
        self.config.max_position_embeddings = max_position_embeddings
        self.config.hidden_size = embed_dim
        super(VTNLongformerModel, self).__init__(self.config, add_pooling_layer=False)
        self.embeddings.word_embeddings = None  # to avoid distributed error of unused parameters


In [9]:
class VTN(nn.Module):
    def __init__(self, multitask="angle", backbone="resnet", device="cuda", multitask_param=True, concept_features=False, train_concepts=False, return_concepts=False):
        super(VTN, self).__init__()
        self.device = device
        self.return_concepts = return_concepts
        self.train_concepts = train_concepts
    
        self._construct_network(multitask, backbone, multitask_param, concept_features)

    def _construct_network(self, multitask, backbone, multitask_param, concept_features):
        clip_model, clip_preprocess = clip.load("ViT-B/32", device="cpu")
        self.clip_model = clip_model
        self.clip_preprocess = clip_preprocess
        self.clip_model.eval()
        self.concept_features = concept_features
        self.backbone_name = backbone

        additional_feat_size = 3 if not concept_features else len(scenarios)+3

        if backbone == "vit":
            print("using vit backbone")
            self.backbone = vit_base_patch16_224(pretrained=True,num_classes=0,drop_path_rate=0.0,drop_rate=0.0)
            num_attention_heads=3 if not concept_features else 7
            mlp_size = 768+additional_feat_size #image feature size + previous sensor feature size 
            embed_dim = 768+additional_feat_size
        elif backbone== "resnet":
            print("using resnet backbone")
            resnet = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
            self.backbone = torch.nn.Sequential(*list(resnet.children())[:-1])
            embed_dim = 512+additional_feat_size #image feature size + previous sensor feature size 
            num_attention_heads=5 #if not concept_features else 6
            mlp_size = 512+additional_feat_size #image feature size + previous sensor feature size 
        elif backbone == "none" and concept_features:
            print("using concept features")
            embed_dim = len(scenarios)+3
            num_attention_heads=1
            mlp_size = len(scenarios)+3
        elif backbone == "clip":
            print("using clip backbone")
            self.backbone = lambda x: clip_model.encode_image(x)
            embed_dim = 512+3
            num_attention_heads=5
            mlp_size = 512+3

        self.multitask = multitask
        self.multitask_param = multitask_param
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        #self.pe = PositionalEncoding(embed_dim) #Did not add positional encoding because VTN paper found better results without

        self.temporal_encoder = VTNLongformerModel(
            embed_dim=embed_dim,
            max_position_embeddings=288,
            num_attention_heads=num_attention_heads,
            num_hidden_layers=3,
            attention_mode='sliding_chunks',
            pad_token_id=-1,
            attention_window=[8, 8, 8],
            intermediate_size=3072,
            attention_probs_dropout_prob=0.3,
            hidden_dropout_prob=0.3)
        num_classes = 1
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(mlp_size),
            nn.Linear(mlp_size, mlp_size),
            nn.GELU(),
            nn.Dropout(0.5),
            nn.Linear(mlp_size, num_classes)
        )
        self.mlp_head_2 = nn.Sequential(
                nn.LayerNorm(mlp_size),
                nn.Linear(mlp_size, mlp_size),
                nn.GELU(),
                nn.Dropout(0.5),
                nn.Linear(mlp_size, num_classes)
            )
        if self.multitask and self.multitask_param:
            self.multitask_param_angle = nn.Parameter(torch.tensor([1.0]))
            self.multitask_param_dist = nn.Parameter(torch.tensor([1.0]))

    def forward(self, img, angle, distance, vego):
        # we need to roll the previous sensor features, so that we do not include the step that we want to predict
        # we also substitude empty 0th entry then with 1st entry
        x = img
        if self.concept_features:
            s = img.shape#[batch_size, seq_len, h,w,c]
            logits_per_image, logits_per_text = self.clip_model(img.reshape((img.shape[0]*img.shape[1], img.shape[2], img.shape[3], img.shape[4])), scenarios_tokens.to(x.device))
            probs = logits_per_image.softmax(dim=-1)
            probs = logits_per_image.reshape((int(img.shape[0]), int(logits_per_image.shape[0]/img.shape[0]), -1))
            if not self.train_concepts: probs = probs.detach()

        angle = torch.roll(angle, shifts=1, dims=1)
        angle[:,0] = angle[:,1]
        distance = torch.roll(distance, shifts=1, dims=1)
        distance[:,0] = distance[:,1]
        vego = torch.roll(vego, shifts=1, dims=1)
        vego[:,0] = vego[:,1]

        # spatial backbone
        B, F, C, H, W = x.shape
        if self.backbone_name != "none":
            x = x.reshape(B * F, C, H, W)
            x = self.backbone(x)
            if self.backbone_name == "clip":
                if not self.train_concepts:
                    x = x.detach()
            x = x.reshape(B, F, -1)
            

        #concatenate the sensor features 
        if self.concept_features:
            x = torch.cat([x, probs], dim=-1) if self.backbone_name != 'none' else probs
        x = torch.cat((x, angle.unsqueeze(-1)), dim=-1)
        x = torch.cat((x, distance.unsqueeze(-1)), dim=-1)
        x = torch.cat((x, vego.unsqueeze(-1)), dim=-1)

        # temporal encoder (Longformer)
        B, D, E = x.shape
        attention_mask = torch.ones((B, D), dtype=torch.long, device=x.device)
        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)

        cls_atten = torch.ones(1).expand(B, -1).to(x.device)
        attention_mask = torch.cat((attention_mask, cls_atten), dim=1)
        attention_mask[:, 0] = 2 # initialize the start with a special number

        
        x, attention_mask, _ = pad_to_window_size_local(
            x,
            attention_mask,
            x,#position_ids, in case we wanted them
            self.temporal_encoder.config.attention_window[0],
            self.temporal_encoder.config.pad_token_id)

        token_type_ids = torch.zeros(x.size()[:-1], dtype=torch.long, device=x.device)
        token_type_ids[:, 0] = 1

        x = self.temporal_encoder(input_ids=None,
                                  attention_mask=attention_mask,
                                  token_type_ids=token_type_ids,
                                  position_ids=None,#position_ids,
                                  inputs_embeds=x,
                                  output_attentions=True,
                                  output_hidden_states=None,
                                  return_dict=True)
        
        # MLP head
        attentions = x['attentions']
        x = x["last_hidden_state"]
        

        if self.multitask:
            x2 = self.mlp_head_2(x)
        x = self.mlp_head(x)
        if self.multitask != "multitask":
            res = x[:,1:F+1,:], attentions
            return res, attentions # we want to exclude the starting token since we don't have any previous knowledge about it 
        else:
           # res = (x[:,1:F+1,:], x2[:,1:F+1,:],self.multitask_param_angle, self.multitask_param_dist), attentions
            res = (x[:,1:F+1,:], x2[:,1:F+1,:], self.multitask_param_angle, self.multitask_param_dist)
            return res, attentions # we want to exclude the starting token since we don't have any previous knowledge about it 

In [11]:
class LaneModule(pl.LightningModule):
    def __init__(self, model, bs, multitask="angle", dataset="comma", time_horizon=1, ground_truth="desired", intervention=False, dataset_path=None, dataset_fraction=1.0):
        super(LaneModule, self).__init__()
        self.dataset_fraction = dataset_fraction
        self.model = model
        self.dataset = dataset
        self.ground_truth = ground_truth
        self.intervention = intervention
        self.dataset_path = dataset_path
        self.num_workers = 4
        self.multitask = multitask
        self.bs = bs
        self.time_horizon = time_horizon
        self.loss = self.mse_loss
        #self.save_hyperparameters(ignore=['model'])
        self.bce_loss = nn.BCELoss()

    def forward(self, x, angle, distance, vego):
        return self.model(x, angle, distance, vego)

    def mse_loss(self, input, target, mask, reduction="mean"):
        input = input.float()
        target = target.float()

        out = (input[~mask]-target[~mask])**2
        return out.mean() if reduction == "mean" else out 

    def calculate_loss(self, logits, angle, distance):
        sm = nn.Softmax(dim=1)
        if self.multitask == "multitask":
            logits_angle, logits_dist, param_angle, param_dist = logits
            mask = distance.squeeze() == 0.0
            if not self.intervention:
                loss_angle = torch.sqrt(self.loss(logits_angle.squeeze(), angle.squeeze(), mask))
            else: 
                angle, distance = distance, angle
                mask = distance.squeeze() == 0.0
                loss_angle = self.bce_loss(sm(logits_angle.float()).squeeze()[~mask], angle.float().squeeze()[~mask])
            loss_distance = torch.sqrt(self.loss(logits_dist.squeeze(), distance.squeeze(), mask))
            if loss_angle.isnan() or loss_distance.isnan():
                print("ERROR")
            loss = loss_angle, loss_distance
            self.log_dict({"train_loss_angle": loss_angle.detach()}, on_epoch=True, batch_size=self.bs)
            self.log_dict({"train_loss_distance": loss_distance.detach()}, on_epoch=True, batch_size=self.bs)
            return loss_angle, loss_distance, param_angle, param_dist
        else:
            mask = distance.squeeze() == 0.0
            loss = torch.sqrt(self.loss(logits.squeeze(), angle.squeeze(), mask))
            return loss

    def training_step(self, batch, batch_idx):
        _, image_array, vego, angle, distance, m_lens, i_lens, s_lens, a_lens, d_lens = batch
        logits, attns = self(image_array, angle, distance, vego)
        loss = self.calculate_loss(logits, angle, distance)
        if self.multitask == "multitask":
            loss_angle, loss_dist, param_angle, param_dist = loss
            param_angle, param_dist = 0.3, 0.7
            loss = (param_angle * loss_angle) + (param_dist * loss_dist)
            self.log_dict({"val_loss_dist": loss_dist}, on_epoch=True, batch_size=self.bs)
            self.log_dict({"val_loss_angle": loss_angle}, on_epoch=True, batch_size=self.bs)
        self.log_dict({"train_loss": loss}, on_epoch=True, batch_size=self.bs)
        return loss

    def predict_step(self, batch, batch_idx):
        _, image_array, vego, angle, distance, m_lens, i_lens, s_lens, a_lens, d_lens = batch
        if self.time_horizon > 1:
            logits_all = []
            for i in range(self.time_horizon, vego.shape[1], self.time_horizon):
                for j in range(self.time_horizon):
                    input_ids_img, input_ids_vego, input_ids_angle, input_ids_distance = image_array[:,0:i+j, :, :, :], vego[:,0:i+j], angle[:,0:i+j], distance[:,0:i+j]
                    if self.multitask == "angle" and len(logits_all) > 0:
                        angle[:,i+j] = torch.tensor(logits_all)[-1]
                    if self.multitask == "distance" and len(logits_all) > 0:
                        distance[:,i+j] = torch.tensor(logits_all)[-1]
                    if self.multitask == "multitask":
                        logits, attns = self(input_ids_img, input_ids_angle, input_ids_distance, input_ids_vego)
                        logits = logits[0][:, -1], logits[1][:, -1]
                    else:
                        logits, attns = self(input_ids_img, input_ids_angle, input_ids_distance, input_ids_vego)[:, -1]
                    logits_all.append(logits)
            return torch.stack(logits_all), angle[:,self.time_horizon:], distance[:,self.time_horizon:]

        
        logits, attns = self(image_array, angle, distance, vego)
        return logits, angle, distance

    def validation_step(self, batch, batch_idx):
        _, image_array, vego, angle, distance, m_lens, i_lens, s_lens, a_lens, d_lens = batch
        logits, attns = self(image_array, angle, distance, vego)
        loss = self.calculate_loss(logits, angle, distance)
        if self.multitask == "multitask":
            loss_angle, loss_dist, param_angle, param_dist = loss
            param_angle, param_dist = 0.3, 0.7
            loss = (param_angle * loss_angle) + (param_dist * loss_dist)
            self.log_dict({"val_loss_dist": loss_dist}, on_epoch=True, batch_size=self.bs)
            self.log_dict({"val_loss_angle": loss_angle}, on_epoch=True, batch_size=self.bs)
        self.log_dict({"val_loss": loss}, on_epoch=True, batch_size=self.bs)
        
        return loss

    def test_step(self, batch, batch_idx):
        _, image_array, vego, angle, distance, m_lens, i_lens, s_lens, a_lens, d_lens = batch
        if self.time_horizon > 1:
            logits_all = []
            for i in range(self.time_horizon,vego.shape[1], self.time_horizon):
                for j in range(self.time_horizon+ 1 ):
                    input_ids_img, input_ids_vego, input_ids_angle, input_ids_distance = image_array[:,0:i+j, :, :, :], vego[:,0:i+j], angle[:,0:i+j], distance[:,0:i+j]
                    if self.multitask == "angle":
                        angle[:,i+j] = logits[:,-1]
                    if self.multitask == "distance":
                        distance[:,i+j] = input_ids_distance[:,-1]
                    logits, attns = self(input_ids_img, input_ids_angle, input_ids_distance, input_ids_vego)[:, -1]
                    logits_all.append(logits)
            loss = self.calculate_loss(torch.stack(logits_all), angle[:,self.time_horizon:], distance[:,self.time_horizon:])
            self.log_dict({"test_loss": loss}, on_epoch=True, batch_size=self.bs)
            return loss
    
        _, image_array, vego, angle, distance, m_lens, i_lens, s_lens, a_lens, d_lens = batch
        logits, attns = self(image_array, angle, distance, vego)
        loss = self.calculate_loss(logits, angle, distance)
        
        if self.multitask == "multitask":
            loss_angle, loss_dist, param_angle, param_dist = loss
            param_angle, param_dist = 0.3, 0.7
            loss = (param_angle * loss_angle) + (param_dist * loss_dist)
            self.log_dict({"test_loss_dist": loss_dist}, on_epoch=True, batch_size=self.bs)
            self.log_dict({"test_loss_angle": loss_angle}, on_epoch=True, batch_size=self.bs)
        self.log_dict({"test_loss": loss}, on_epoch=True, batch_size=self.bs)
        return loss
    def train_dataloader(self):
        return self.get_dataloader(dataset_type="train")

    def val_dataloader(self):
        return self.get_dataloader(dataset_type="val")

    def test_dataloader(self):
        return self.get_dataloader(dataset_type="test")

    def predict_dataloader(self):
        return self.get_dataloader(dataset_type="test")

    def configure_optimizers(self):
        g_opt = torch.optim.Adam(self.model.parameters(), lr=1e-5, weight_decay=1e-5)
        return g_opt

    def get_dataloader(self, dataset_type):
        if self.dataset == "once":
            ds = ONCEDataset(dataset_type=dataset_type, multitask=self.multitask) 
        elif self.dataset == "comma":
            ds = CommaDataset(dataset_type=dataset_type, multitask=self.multitask if not self.intervention else "intervention", ground_truth=self.ground_truth, dataset_path=self.dataset_path, dataset_fraction=self.dataset_fraction)
        elif self.dataset == 'nuscenes':
            ds = NUScenesDataset(dataset_type=dataset_type, multitask=self.multitask if not self.intervention else "intervention", ground_truth=self.ground_truth, max_len=20, dataset_path=self.dataset_path, dataset_fraction=self.dataset_fraction)
        return DataLoader(ds, batch_size=self.bs, num_workers=self.num_workers, collate_fn=pad_collate)
        

    

In [12]:
class Chunk1DataModule(pl.LightningDataModule):
    def __init__(self, batch_size=4):
        super().__init__()
        self.batch_size=batch_size
    def setup(self, stage=None):
        self.train_ds = CommaDataset('train')
        self.val_ds   = CommaDataset('val')
        self.test_ds  = CommaDataset('test')
    def train_dataloader(self): return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True)
    def val_dataloader(self):   return DataLoader(self.val_ds,   batch_size=self.batch_size)
    def test_dataloader(self):  return DataLoader(self.test_ds,  batch_size=self.batch_size)


In [13]:
def save_preds(logits, target, save_name, p):
    b, s = target.shape
    df = pd.DataFrame()
    df['logits'] = logits.squeeze().reshape(b*s).tolist()
    df['target'] = target.squeeze().reshape(b*s).tolist()
    df.to_csv(f'{p}/{save_name}.csv', mode='a', index=False, header=False)

'''Define the argument parser'''
def get_arg_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument('-task', default='angle', type=str)
    parser.add_argument('-train', action='store_true')
    parser.add_argument('-test', action='store_true')
    parser.add_argument('-gpu_num', default=0, type=int)
    parser.add_argument('-dataset', default='comma', type=str)
    parser.add_argument('-dataset_path', default='/kaggle/input/filtered-chunk1', type=str)
    parser.add_argument('-bs', default=8, type=int)
    parser.add_argument('-max_epochs', default=10, type=int)
    parser.add_argument('-ground_truth', default='desired', type=str)
    parser.add_argument('-new_version', action='store_true')
    parser.add_argument('-intervention_prediction', action='store_true')
    parser.add_argument('-dev_run', action='store_true')
    parser.add_argument('-backbone', default='resnet', type=str, choices=['resnet', 'vit'], help='Backbone model type')
    parser.add_argument('-concept_features', action='store_true', help='Use concept features')
    parser.add_argument('-train_concepts', action='store_true', help='Train concept features')
    parser.add_argument('-dataset_fraction', default=1.0, type=float, help='Fraction of dataset to use')
    return parser


In [14]:
if __name__ == "__main__":
    from torch.nn import DataParallel
    # Configurazione manuale per il primo training
    args_list = [
        '-train',
        '-task', 'multitask',  # o 'angle'/'distance'
        '-dataset', 'comma',
        '-dataset_path', '/kaggle/input/filtered-chunk1',
        '-bs', '2',  # Batch size ridotto per iniziare
        '-max_epochs', '1',
        '-backbone', 'resnet',
        '-ground_truth', 'desired',
        '-concept_features',  # Aggiungi se necessario
        '-train_concepts'    # Aggiungi se necessario
    ]
    
    parser = get_arg_parser()
    args = parser.parse_args(args=args_list)
    
    # Setup iniziale
   # torch.cuda.empty_cache()
   # os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:50"
   # if torch.cuda.device_count() > 0 and torch.cuda.get_device_capability()[0] >= 7:
   #     torch.set_float32_matmul_precision('high')

    # Inizializzazione modello
   # device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model = VTN(
        multitask=args.task,
        backbone=args.backbone,
        concept_features=args.concept_features,
        device="cpu",
        train_concepts=args.train_concepts
    )

    if torch.cuda.is_available() and torch.cuda.device_count() > 1:
        device = torch.device("cuda")
        model = DataParallel(model, device_ids=[0,1])  # scegli 0,1,2… in base alle GPU
        model.to(device)
    else:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model.to(device)

    module = LaneModule(model, multitask=args.task, dataset = args.dataset, bs=args.bs, ground_truth=args.ground_truth, intervention=args.intervention_prediction, dataset_path=args.dataset_path, dataset_fraction=args.dataset_fraction)


    # Configurazione checkpoint e logger
    ckpt_pth = f"/kaggle/working/ckpts_final_{args.dataset}_{args.task}_{args.backbone}_{args.concept_features}_{args.dataset_fraction}"
    checkpoint_callback = ModelCheckpoint(
        dirpath=ckpt_pth,
        filename='best-{epoch}-{val_loss:.2f}',
        save_top_k=1,
        monitor="val_loss",
        mode="min"
    )
    logger = TensorBoardLogger(save_dir=ckpt_pth)

    # Configurazione trainer
    trainer = pl.Trainer(
       accelerator="gpu",
       devices=1,                # usa 2 GPU
       precision="16-mixed",             # mixed precision FP16 per dimezzare l’uso di VRAM
       max_epochs=args.max_epochs,
       logger=logger,
       callbacks=[
            TQDMProgressBar(refresh_rate=10),
            checkpoint_callback,
            EarlyStopping(monitor="val_loss", patience=3)
        ],
        enable_checkpointing=True
    )

    # Training
    trainer.fit(module)
    
    # Salvataggio configurazione
    save_path = "/".join(checkpoint_callback.best_model_path.split("/")[:-1])
    with open(f'{save_path}/hparams.yaml', 'w') as f:
        yaml.dump(vars(args), f)  # Nota: usiamo vars() per convertire Namespace in dict
    
    print(f"Training completato! Checkpoint salvato in: {checkpoint_callback.best_model_path}")

using resnet backbone


/usr/local/lib/python3.11/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /kaggle/working/ckpts_final_comma_multitask_resnet_True_1.0 exists and is not empty.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.11/dist-packages/pytorch_lightning/loops/fit_loop.py:310: The number of training batches (35) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Training completato! Checkpoint salvato in: /kaggle/working/ckpts_final_comma_multitask_resnet_True_1.0/best-epoch=0-val_loss=32.53.ckpt


In [15]:
best_ckpt = checkpoint_callback.best_model_path
preds = trainer.predict(module, ckpt_path=best_ckpt if best_ckpt else "best")

for (logits, angle_gt, dist_gt) in preds:
    if args.task != "multitask":
        # single-task
        save_preds(
            logits, angle_gt,
            f"{args.dataset}_{args.task}_{args.backbone}_{args.concept_features}_{args.n_scenarios}",
            save_path
        )
    else:
        # estrai solo i primi due (angle_preds, dist_preds)
        angle_preds, dist_preds = logits[0], logits[1]

        save_preds(
            angle_preds, angle_gt,
            f"angle_multi_{args.dataset}_{args.task}_{args.backbone}_{args.concept_features}",
            save_path
        )
        save_preds(
            dist_preds, dist_gt,
            f"dist_multi_{args.dataset}_{args.task}_{args.backbone}_{args.concept_features}",
            save_path
        )

print(f"Prediction completate, CSV salvati in: {save_path}")


Predicting: |          | 0/? [00:00<?, ?it/s]

Prediction completate, CSV salvati in: /kaggle/working/ckpts_final_comma_multitask_resnet_True_1.0
