# CONFIGURATION:

In [6]:
from jaad_data import JAAD
import torch
from PIL import Image
from torchvision import transforms
from torchvision import models
import matplotlib.pyplot as plt
import network
import openpose
from openpose import model
from openpose import util
from openpose.body import Body
import copy
from tqdm import tqdm
import pickle
from torch.utils.data import Dataset, DataLoader
import numpy as np
import cv2
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import torchmetrics
from torcheval.metrics import BinaryAccuracy
from torchmetrics.classification import BinaryAccuracy, BinaryPrecision, BinaryRecall, BinaryF1Score



In [7]:
JAAD_PATH = '../JAAD'
DEEPLAB_PATH = '../best_deeplabv3plus_resnet101_cityscapes_os16.pth'
SUBSET_PATH = '../subset2'

RESULTS_MASK_SUB = '../masks_results_sub.pkl'
RESULTS_MASK_BIG = '../masks_results_big.pkl'
RESULTS_MASK_BIG_TEST = '../masks_results_big_test.pkl'
RESULTS_MASK_SUB_TEST = '../masks_results_sub_test.pkl'

RESULTS_POSE_BIG = '../pose_results_big.pkl'
RESULTS_POSE_SUB = '../pose_results_sub.pkl'
RESULTS_POSE_BIG_TEST = '../pose_results_big_test.pkl'
RESULTS_POSE_SUB_TEST = '../pose_results_sub_test.pkl'

POSE_PATH = '../body_pose_model.pth'   


In [8]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
torch.cuda.empty_cache()


cuda


In [9]:
# CONFIG (change these params for changing the experiment): 
RUN = False
BIG =False 

In [10]:
if BIG:
    MASK_CMD = RESULTS_MASK_BIG
    POSE_CMD = RESULTS_POSE_BIG
    MASK_CMD_TEST = RESULTS_MASK_BIG_TEST
    POSE_CMD_TEST = RESULTS_POSE_BIG_TEST
    DT_CMD = JAAD_PATH
else:
    MASK_CMD = RESULTS_MASK_SUB
    POSE_CMD = RESULTS_POSE_SUB
    MASK_CMD_TEST = RESULTS_MASK_SUB_TEST
    POSE_CMD_TEST = RESULTS_POSE_SUB_TEST
    DT_CMD = SUBSET_PATH



# DATASET

In [11]:
# Load the JAAD dataset
jaad_dt = JAAD(data_path=DT_CMD)

data_opts = {
    'fstride': 15,
    #'subset': 'high_visibility',
    'sample_type': 'all'
}

seq_train = jaad_dt.generate_data_trajectory_sequence('train', **data_opts)  
seq_test = jaad_dt.generate_data_trajectory_sequence('test', **data_opts)  

---------------------------------------------------------
Generating action sequence data
fstride: 15
sample_type: all
subset: default
height_rng: [0, inf]
squarify_ratio: 0
data_split_type: default
seq_type: intention
min_track_size: 15
random_params: {'ratios': None, 'val_data': True, 'regen_data': False}
kfold_params: {'num_folds': 5, 'fold': 1}
---------------------------------------------------------
Generating database for jaad
jaad database loaded from c:\Users\Filippo\Documents\VSCode\Pedestrian_Intention\subset2\data_cache\jaad_database.pkl
---------------------------------------------------------
Generating intention data
Split: train
Number of pedestrians: 185 
Total number of samples: 60 
---------------------------------------------------------
Generating action sequence data
fstride: 15
sample_type: all
subset: default
height_rng: [0, inf]
squarify_ratio: 0
data_split_type: default
seq_type: intention
min_track_size: 15
random_params: {'ratios': None, 'val_data': True, 'r

In [12]:
print((seq_train['image']))
print((seq_test['image'][1]))
# print(len(seq_train['bbox']))

[['../subset2\\images\\video_0001\\00000.png', '../subset2\\images\\video_0001\\00015.png', '../subset2\\images\\video_0001\\00030.png', '../subset2\\images\\video_0001\\00045.png', '../subset2\\images\\video_0001\\00060.png', '../subset2\\images\\video_0001\\00075.png', '../subset2\\images\\video_0001\\00090.png', '../subset2\\images\\video_0001\\00105.png', '../subset2\\images\\video_0001\\00120.png', '../subset2\\images\\video_0001\\00135.png', '../subset2\\images\\video_0001\\00150.png', '../subset2\\images\\video_0001\\00165.png', '../subset2\\images\\video_0001\\00180.png', '../subset2\\images\\video_0001\\00195.png', '../subset2\\images\\video_0001\\00210.png', '../subset2\\images\\video_0001\\00225.png', '../subset2\\images\\video_0001\\00240.png', '../subset2\\images\\video_0001\\00255.png', '../subset2\\images\\video_0001\\00270.png', '../subset2\\images\\video_0001\\00285.png', '../subset2\\images\\video_0001\\00300.png', '../subset2\\images\\video_0001\\00315.png', '../subs

# GLOBAL CONTEXT EXTRACTION:

In [13]:
if RUN:
    deeplab_model = network.modeling.__dict__['deeplabv3plus_resnet101'](num_classes=19)
    deeplab_model.load_state_dict(torch.load(DEEPLAB_PATH)['model_state'])
    deeplab_model.to(device)
    deeplab_model.eval()

In [14]:
#trasformazioni che vengono usate dentro global context (modifica l'input prima che vada in deeplab)

train_transforms = transforms.Compose([
    transforms.Resize((512, 512)),  # Ridimensiona le immagini a 512x512
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

#processa output deeplab
GC_trans = transforms.Compose([
    transforms.Resize((224, 224)),  # Ridimensiona le immagini a 224x224
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [15]:
# Visualizza la maschera semantica
def visualize_mask(image_path, mask):
    image = Image.open(image_path).convert("RGB")
    #image = image.resize((512, 512))  # Ridimensiona per la visualizzazione
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.imshow(image)
    plt.title("Original Image")
    plt.subplot(1, 2, 2)
    plt.imshow(mask, cmap='jet')
    plt.title("Semantic Mask")
    plt.show()


In [16]:
def get_segmentation_mask(image_path, model, preprocess):
    """ funzione che prende in input le path delle imagini, il modello e la funzione di preprocessamento 
    e restituisce la maschera segmentata dell'immagine resizata a 224x224"""
    
    # Load the image
    input_image = Image.open(image_path).convert("RGB")
    input_tensor = preprocess(input_image).to(device)
    input_batch = input_tensor.unsqueeze(0)  
    
    # Pass the image through the model
    with torch.no_grad():
        output = model(input_batch)
        
    # Check if output is a tensor or a dictionary (auxiliary control)
    if isinstance(output, dict):
        output = output['out'][0]
    elif isinstance(output, torch.Tensor):
        output = output[0]
    else:
        raise ValueError(f"Unexpected output type: {type(output)}")
    
    # Convert the output to a mask
    output_predictions = output.argmax(0)

####################################################################################
    #OPTIONAL FOR DEBUG AND PLOTTING:
    # output_predictions_pic = output_predictions.clone().cpu()
    # visualize_mask(image_path,output_predictions_pic) 
####################################################################################

    # Fix dimensions for the mask and convert to float
    output_predictions = output_predictions.unsqueeze(0).unsqueeze(0).float()

    #convert to RGB and resize to (224,224), because VGG want(3,224,224) as input
    tr = transforms.ToPILImage()
    pic = tr(output_predictions.squeeze(1))
    pic= pic.convert("RGB")
    resized_mask = GC_trans(pic)

    return resized_mask

In [17]:
def process_video_frames(seq_train, model, preprocess):
    """funzione che prende in input la sequenza di training, il modello e la funzione di preprocessamento, restituisce una
    lista di segmentation mask per ogni frame di ogni video della sequenza di training"""
    
    all_masks = []
    for video_frames in tqdm(seq_train['image'], desc="Processing videos"):
        video_masks = []

        for frame_path in tqdm(video_frames, desc="Processing frames", leave=False):
            mask = get_segmentation_mask(frame_path, model, preprocess)
            video_masks.append(mask)
        all_masks.append(video_masks)
    
    return all_masks


In [18]:
if RUN:
    all_video_masks = process_video_frames(seq_train, deeplab_model, train_transforms)
    seq_train['masks'] = all_video_masks
    all_video_masks_test = process_video_frames(seq_test, deeplab_model, train_transforms)
    seq_test['masks'] = all_video_masks_test

    #cleaning for my poor gpu
    del deeplab_model
    torch.cuda.empty_cache()

    # save data in the .pkl files 
    with open(MASK_CMD, 'wb') as f:
        pickle.dump(seq_train['masks'], f)
    with open(MASK_CMD_TEST, 'wb') as f:
        pickle.dump(seq_test['masks'], f)
else:
    #recover data:
    with open(MASK_CMD, 'rb') as f:
        seq_train['masks'] = pickle.load(f)
    with open(MASK_CMD_TEST, 'rb') as f:
        seq_test['masks'] = pickle.load(f)


# LOCAL CONTEXT:

In [19]:
# Trasformazioni per le immaginin che vengono per il local context 
transform_lc = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [20]:
def crop_image_cv2(img, bbox):
    """ funzione che croppa i frames sul bounding boxes, le imagini sono in formato cv2"""
    
    x1, y1, x2, y2 = bbox
    return img[int(y1):int(y2), int(x1):int(x2)]

In [21]:
"""Trasformation for the local context's images, enhance the quality of the images by appling gaussian filter, unsharp mask e bilateral filter"""
all_images = [] #list of the images
for i in tqdm(range(len(seq_train['image'])), desc="Processing videos"):
    aux_list = []
    for j in tqdm(range(len(seq_train['image'][i])), desc="Processing frames", leave=False):
        
        # Open the images from the paths
        img_path = seq_train['image'][i][j]
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # Convert to RGB
        
        #get the bbox and crop arround it
        bbox = seq_train['bbox'][i][j]
        cropped_images = crop_image_cv2(img, bbox)
        
        # Resize to 224x224 (VGG input size)
        final_image = cv2.resize(cropped_images, (224, 224))
        aux_list.append(final_image)
    all_images.append(aux_list)

Processing videos: 100%|██████████| 60/60 [00:33<00:00,  1.79it/s]


In [22]:
#same of above, but for the test set
all_images_test = []
for i in tqdm(range(len(seq_test['image'])), desc="Processing videos"):
    aux_list = []
    for j in tqdm(range(len(seq_test['image'][i])), desc="Processing frames", leave=False):
        
        # Open the images from the paths
        img_path = seq_test['image'][i][j]
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # Convert to RGB

        #get the bbox and crop arround it
        bbox = seq_test['bbox'][i][j]
        cropped_images = crop_image_cv2(img, bbox)
        
        # Resize to 224x224 (VGG input size)
        final_image = cv2.resize(cropped_images, (224, 224))
        aux_list.append(final_image)
    all_images_test.append(aux_list)


Processing videos: 100%|██████████| 97/97 [00:54<00:00,  1.78it/s]


# POSE KEYPOINTS: 

In [23]:
def extract_pose_sequence(frames, body_model):
    """funzione che prende in input i frames e il modello di openpose e restituisce una lista di tensori di pose 
    per ciascuna persona nel tempo, se non prende nessuna posa mette un placeholder"""
    #print(frames)
    pose_sequences = []  # Lista di pose per ciascuna persona nel tempo
    pose_placeholder = torch.zeros((36,), dtype=torch.float32)
    for frame in frames:
        candidate, subset = body_model(frame)

        ###############################################################
        # DEBUG & PLOTTING: Uncomment the following
        canvas = copy.deepcopy(frame)
        canvas = util.draw_bodypose(canvas, candidate, subset)
        plt.imshow(canvas[:, :, [2, 1, 0]])  
        plt.axis('off')
        plt.show()
        ###############################################################

        frame_poses = []
        for person in subset:
            if person[-1] >= 4:  # Almeno 4 punti chiave rilevati
                pose = []
                for i in range(18):
                    if person[i] != -1:
                        x, y = candidate[int(person[i])][:2]
                    else:
                        x, y = -1, -1  # Punti chiave mancanti
                    pose.extend([x, y])
                frame_poses.append(pose)
        if not frame_poses:
            frame_poses = [pose_placeholder.tolist()] #use placeholder if no pose detected (for dimensional consistency)
        pose_sequences.append(frame_poses)

    # Trasponi la lista di liste per ottenere le sequenze temporali per ciascuna persona
    person_pose_sequences = list(map(list, zip(*pose_sequences)))
    person_pose_sequences = [torch.tensor(person_poses, dtype=torch.float32) for person_poses in person_pose_sequences]
    
    return person_pose_sequences

In [24]:

if RUN:
    body_model = Body(POSE_PATH)

    # Caricamento dei frame e estrazione delle pose
    all_poses = []
    all_poses_test = []
    # print(len(all_images))
    # print(len(all_images[0]))
    # print(len(all_images[0][0]))
    # print(all_images[0][0].shape)
    #itera tra i video e prendi i frames
    for pics in tqdm(all_images, desc="Extracting poses from image sequences"):
        #print(len(pics))
        pose_sequences = extract_pose_sequence(pics, body_model)
        all_poses.append(pose_sequences)
        #print(pose_sequences.shape)
    # Aggiungi le pose estratte alla sequenza di allenamento
    seq_train['poses'] = all_poses

    for pics in tqdm(all_images_test, desc="Extracting poses from image sequences of test set"):
        pose_sequences = extract_pose_sequence(pics, body_model)
        all_poses_test.append(pose_sequences)
    # Aggiungi le pose estratte alla sequenza di allenamento
    seq_test['poses'] = all_poses_test
    del body_model
    torch.cuda.empty_cache()
    # Apri il file in modalità scrittura binaria e salva il dizionario

    with open(POSE_CMD, 'wb') as f:
        pickle.dump(seq_train['poses'], f)
    with open(POSE_CMD_TEST, 'wb') as f:
        pickle.dump(seq_test['poses'], f)
else:
    #recover data:
    with open(POSE_CMD, 'rb') as f:
        seq_train['poses'] = pickle.load(f)
    with open(POSE_CMD_TEST, 'rb') as f:
        seq_test['poses'] = pickle.load(f)
    # Verifica che i risultati siano stati caricati correttamente
    #print(seq_train['masks'])

# MODEL:

In [25]:
class VisionBranchLocal(torch.nn.Module):
    """definizione del modello per il local context, prende in input le immagini croppate e restituisce un tensore,
    le immagini croppate vengono fatte passare dentro una VGG16, una GRU e un attention block"""

    def __init__(self, vgg16):
        super(VisionBranchLocal, self).__init__()
        self.vgg16 = vgg16
        self.avgpool = torch.nn.AvgPool2d(kernel_size=14)  # Pooling layer con kernel 14x14
        self.gru = torch.nn.GRU(input_size=512, hidden_size=256, num_layers=2, batch_first=True)
        self.fc = torch.nn.Linear(256, 2)    # Fully connected layer
        self.attn = torch.nn.Linear(256, 1)  # Attention layer
        self.tanh = torch.nn.Sigmoid()

    def forward(self, cropped_images):
        seq_len, c, h, w = cropped_images.size()
        
        # Estrai feature dalle immagini con VGG16
        vgg_features = []
        for i in range(seq_len):            

            img = cropped_images[i]            
            vgg_feat_img = self.vgg16.features(img)
            pooled_feat_img = self.avgpool(vgg_feat_img)  # Applica il pooling
            vgg_feat_img = pooled_feat_img.view(pooled_feat_img.size(0), -1)  # Flatten features
            vgg_features.append(vgg_feat_img)
        
        vgg_features = torch.stack(vgg_features, dim=1).permute(2,1,0)

        #applica la gru
        gru_out, _ = self.gru(vgg_features)

        #applica l'attention
        attn_weights = torch.softmax(self.attn(gru_out), dim=1)
        context_vector = torch.sum(attn_weights * gru_out, dim=1)
        
       # print("SIZE vgg features:",vgg_features.shape)
       # print("SIZE context v local context:",context_vector.shape)
        #out = self.sigmoid(self.fc(gru_out[:, -1, :]))   
        out = self.tanh((context_vector))
        return out


In [26]:
class VisionBranchGlobal(torch.nn.Module):
    """definizione del modello per il global context, prende in input le maskere semantiche e restituisce un tensore,
    le maskere semantiche vengono fatte passare dentro una VGG16, una GRU e un attention block"""

    def __init__(self, vgg16):
        super(VisionBranchGlobal, self).__init__()
        self.vgg16 = vgg16
        self.avgpool = torch.nn.AvgPool2d(kernel_size=14)  # Pooling layer con kernel 14x14
        self.gru = torch.nn.GRU(input_size=512, hidden_size=256, num_layers=2, batch_first=True)
        self.fc = torch.nn.Linear(256, 2)    # Fully connected layer
        self.attn = torch.nn.Linear(256, 1)  # Attention layer
        self.tanh = torch.nn.Tanh()

    def forward(self, masks):
        seq_len = masks.size()[0]
        #print("size forward:",seq_len)
        # Estrai feature dalle immagini con VGG16
        vgg_features = []
        for i in range(seq_len):            
            img = masks[i]            
            vgg_feat_img = self.vgg16.features(img)
            pooled_feat_img = self.avgpool(vgg_feat_img)  # Applica il pooling
            vgg_feat_img = pooled_feat_img.view(pooled_feat_img.size(0), -1)  # Flatten features
            vgg_features.append(vgg_feat_img)
        
        vgg_features = torch.stack(vgg_features, dim=1).permute(2,1,0)

        gru_out, _ = self.gru(vgg_features)
        attn_scores = self.attn(gru_out)  # shape: (batch_size, seq_length, 1)
       # print("SIZE attention scores GLO:", attn_scores.shape)
        
        attn_weights = torch.softmax(attn_scores, dim=1)  # shape: (batch_size, seq_length, 1)
        #print("SIZE attention weights GLO:", attn_weights.shape)
        
        context_vector = torch.sum(attn_weights * gru_out, dim=1)  # shape: (batch_size, 256)
        #print("SIZE context vector GLO:", context_vector.shape)
        
        out = self.tanh((context_vector))
        return out


In [27]:
class NVisionBranch(torch.nn.Module):
    """classe relativa al non-vision brach, prende in input le pose e le bbox in formato tensore, esse vengono fatte passare
    dentro una GRU e un attention block, l'ordine influenza la prestazioni"""
    
    def __init__(self):
        super(NVisionBranch, self).__init__()
        self.gru = torch.nn.GRU(input_size=36, hidden_size=256, num_layers=2, batch_first=True)
        self.gru2 = torch.nn.GRU(input_size=256+4, hidden_size=256, num_layers=2, batch_first=True)
        #self.fc = torch.nn.Linear(256, 2)  
        self.attn = torch.nn.Linear(256, 1)  # Attention layer
        self.tanh = torch.nn.Tanh()

    def forward(self, poses,bbox):
        gru_out, _ = self.gru(poses)
        #print("SIZE outuyput gru posa:",gru_out.shape)
        #print("bbox:",bbox.shape)
        LP = torch.cat((gru_out,bbox),dim=-1)
        #print("SIZE outuyput gru posa + bbox:",LP.shape)
        gru_out, _ = self.gru2(LP)
       # print("SIZE output gru:",gru_out.shape)

        # Attention mechanism
        #features = torch.stack([gru_out[:,i,:] for i in range(gru_out.size(1))], dim=1)
        attn_scores = self.attn(gru_out)  # shape: (batch_size, seq_length, 1)
       # print("SIZE attention scores:", attn_scores.shape)
        
        attn_weights = torch.softmax(attn_scores, dim=1)  # shape: (batch_size, seq_length, 1)
        #print("SIZE attention weights:", attn_weights.shape)
        
        context_vector = torch.sum(attn_weights * gru_out, dim=1)  # shape: (batch_size, 256)
        #print("SIZE context vector:", context_vector.shape)
        
        #out = self.sigmoid(self.fc(gru_out[:, -1, :]))   
        out = self.tanh(context_vector)
        return out


In [28]:
class PedestrianIntentModel(torch.nn.Module):
    """definizione del modello finale, prende in input gli output del vision brach e del non vision branch, viene fatta
    una concatenazione che in seguito passa dentro un attention e un fully connected layer, l'output è la predizione,
    (passa non passa)"""

    def __init__(self, vision_branch_local,vision_branch_global,non_vision_branch):
        super(PedestrianIntentModel, self).__init__()
        self.vision_branch_local = vision_branch_local #output of the vision branch local
        self.vision_branch_global = vision_branch_global #output of the vision branch global
        self.non_vision_branch = non_vision_branch #output of the non vision branch
        self.attn = torch.nn.Linear(768, 768)  # Attention layer

        self.fc1 = torch.nn.Linear(768, 256) # fully connected layer
        self.fc2 = torch.nn.Linear(256,1) # Output: crossing or not crossing

        
    def forward(self, cropped_images, bboxes, masks, poses):
        vision_out_local = self.vision_branch_local(cropped_images)
        vision_out_global = self.vision_branch_global(masks)
        non_vision_out = self.non_vision_branch(poses, bboxes)

        vision_out = torch.cat((vision_out_local, vision_out_global), dim=-1)
        final_fusion = torch.cat((vision_out, non_vision_out), dim=-1)

        attn_scores = self.attn(final_fusion)
        attn_weights = torch.softmax(attn_scores, dim=-1)
        context_vector = torch.sum(attn_weights * final_fusion, dim=0)

        out = (self.fc2(self.fc1(context_vector))) #raw output

        return out


In [29]:
class VGG16_FeatureExtractor(torch.nn.Module):

    def __init__(self):
        super(VGG16_FeatureExtractor, self).__init__()
        self.features = torch.nn.Sequential(*list(vgg16.features.children())[:24]) # block4_pool è il 24° livello
    def forward(self, x):
        x = self.features(x)
        return x

In [30]:
# Carica il modello VGG19 pre-addestrato
vgg16 = models.vgg16(pretrained=True)

#cut the model at the 24th layer:
vgg16_fe = VGG16_FeatureExtractor()
vgg16_fe


# define the models of each branches
model_local = VisionBranchLocal(vgg16_fe).to(device)
model_global = VisionBranchGlobal(vgg16_fe).to(device)
model_non_vision = NVisionBranch().to(device)
model = PedestrianIntentModel(model_local,model_global,model_non_vision).to(device)




# DATASET & DATALOADER

In [31]:
class JAADDataset(Dataset):
    """definizione della classe per il custom dataset, prende in input la seq_train, le immagini e le trasformazioni,
    restituisce il tensore delle immagini croppate, le bboxes, le maschere, le pose e la lables"""

    def __init__(self, seq_data, all_images, transform=None):
        self.seq_data = seq_data
        self.all_images = all_images
        self.transform = transform

    def __len__(self):
        return len(self.seq_data['image'])

    def __getitem__(self, idx):
        bbox_sequence = self.seq_data['bbox'][idx]
        masks = self.seq_data['masks'][idx]
        poses = self.seq_data['poses'][idx]
        all_images = self.all_images[idx]

        if self.transform:
            tensor_images = [self.transform(img) for img in all_images]

        bboxes = torch.tensor(self.seq_data['bbox'][idx], dtype=torch.float32)
        intents = torch.tensor(self.seq_data['intent'][idx], dtype=torch.float32)
        return  tensor_images, bboxes, masks, poses, intents

In [32]:
train_dataset = JAADDataset(seq_train,all_images=all_images, transform=transform_lc)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_dataset = JAADDataset(seq_test,all_images=all_images_test, transform=transform_lc)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)

# TRAINING:

In [33]:
criterion = torch.nn.BCEWithLogitsLoss() #input are raw output
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001) 

In [34]:
def evaluate_metrics(net, loader, device):
    acc = BinaryAccuracy().to(device)
    precision = BinaryPrecision().to(device)
    recall = BinaryRecall().to(device)
    f1_score = BinaryF1Score().to(device)
    
    model.eval()

    for tensor_images, bboxes, masks, poses, intents in loader:
        poses = torch.stack(poses, dim=0)  
        poses = poses.squeeze(0)
        poses = poses.view(len(tensor_images), -1, 36).permute(1, 0, 2) 
        
        # Move tensors to device
        tensor_images = torch.stack(tensor_images, dim=1).squeeze(0).permute(0, 1, 2, 3).to(device)  # Convert image list to tensor
        masks = torch.stack(masks, dim=1).squeeze(0).float().to(device)  # Convert mask list to tensor
        bboxes = bboxes.to(device)
        poses = poses.to(device)
        intents = intents.squeeze(0)[0].to(device)
        
        ypred = net(tensor_images, bboxes, masks, poses)

        # Move tensors back to CPU and clean up memory
        tensor_images.cpu()
        masks.cpu()
        bboxes.cpu()
        poses.cpu()
        del tensor_images, bboxes, masks, poses
        torch.cuda.empty_cache()
        
        # Update metrics
        acc.update(ypred, intents)
        precision.update(ypred, intents)
        recall.update(ypred, intents)
        f1_score.update(ypred, intents)
        
        intents.cpu()
        del intents
        torch.cuda.empty_cache()
    
    metrics = {
        'accuracy': acc.compute().item(),
        'precision': precision.compute().item(),
        'recall': recall.compute().item(),
        'f1_score': f1_score.compute().item()
    }
    
    # Reset metrics for the next evaluation
    acc.reset()
    precision.reset()
    recall.reset()
    f1_score.reset()
    
    return metrics


In [35]:

num_epochs = 100

# free the memory
torch.cuda.empty_cache()

# Placeholder tensor for empty poses
pose_placeholder = torch.zeros((36,), dtype=torch.float32) 

for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}')
    model.train()

    for tensor_images, bboxes, masks, poses, intents in tqdm(train_loader):
        #case when there are no poses in the frame 
        # if len(poses) == 0:
        #     print("Empty poses detected, adding placeholder tensor.")
        #     poses = [pose_placeholder for _ in range(len(tensor_images))]
        #     poses = torch.stack(poses, dim=0)  # Now the shape is (batch_size, 36)
        #     #poses = poses.squeeze(0)
        #     poses = poses.view(len(tensor_images), -1, 36).permute(1,0,2)  # Reshape to (batch_size, numeroFrames, 36)
        # else:
        #     #poses = [torch.tensor(p, dtype=torch.float32) for p in poses]
        #     poses = torch.stack(poses, dim=0)  # Now the shape is (batch_size, 36)
        #     poses = poses.squeeze(0)
        # poses = [pose_placeholder for _ in range(len(tensor_images))]
        # poses = torch.stack(poses, dim=0)  # Now the shape is (batch_size, 36)
        poses = torch.stack(poses, dim=0)  # Now the shape is (batch_size, 36)
        poses = poses.squeeze(0)
        #poses = poses.squeeze(0)
        poses = poses.view(len(tensor_images), -1, 36).permute(1,0,2)  # Reshape to (batch_size, numeroFrames, 36)
        #print("TRAINING:")
        #print("poses",len(poses))
        # Convert poses to tensor and reshape
        #print("poses shape:",poses.shape)   

        # Move tensors to device
        tensor_images = torch.stack(tensor_images, dim=1).squeeze(0).permute(0, 1, 2,3).to(device)  # Converte la lista di immagini in un tensor
        masks = torch.stack(masks,dim=1).squeeze(0).float().to(device)  # Converte la lista di maschere in un tensor
        bboxes = bboxes.to(device)
        poses = poses.to(device)
        intents = intents.squeeze(0)[0].to(device)
        # print("ti",tensor_images.shape)
        # print("BBOX:", bboxes.shape)
        # print("MASKS:", masks.shape)
        # print("POSES:", poses.shape)
        # print("INTENTS:", intents.shape)

        optimizer.zero_grad()

        outputs = model(tensor_images, bboxes, masks, poses)
        tensor_images.cpu()
        #print("TRAIN out vs intents",outputs,intents)
        masks.cpu()
        bboxes.cpu()
        poses.cpu()
        #print(outputs)
        #outputs = torch.argmax((outputs))
        loss = criterion(outputs, intents.float())
        loss.backward()
        #for name, param in model.named_parameters():
            #if param.grad is not None:
                #print(f"Gradiente di {name}: {param.grad.norm()}")

        optimizer.step()
        del tensor_images, masks, bboxes, poses, intents, outputs
        torch.cuda.empty_cache()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
    with torch.no_grad():
        print(f'Accuracy at epoch {epoch}: {evaluate_metrics(model, test_loader, device)}')
        



Epoch 1/100


100%|██████████| 60/60 [00:51<00:00,  1.17it/s]


Epoch [1/100], Loss: 0.7873
Accuracy at epoch 0: {'accuracy': 0.7938144207000732, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
Epoch 2/100


100%|██████████| 60/60 [00:49<00:00,  1.21it/s]


Epoch [2/100], Loss: 0.8152
Accuracy at epoch 1: {'accuracy': 0.7938144207000732, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
Epoch 3/100


100%|██████████| 60/60 [00:43<00:00,  1.37it/s]


Epoch [3/100], Loss: 0.7681
Accuracy at epoch 2: {'accuracy': 0.7938144207000732, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
Epoch 4/100


100%|██████████| 60/60 [00:46<00:00,  1.30it/s]


Epoch [4/100], Loss: 0.5769
Accuracy at epoch 3: {'accuracy': 0.7938144207000732, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
Epoch 5/100


100%|██████████| 60/60 [00:44<00:00,  1.35it/s]


Epoch [5/100], Loss: 0.5858
Accuracy at epoch 4: {'accuracy': 0.7938144207000732, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
Epoch 6/100


100%|██████████| 60/60 [00:44<00:00,  1.36it/s]


Epoch [6/100], Loss: 0.4555
Accuracy at epoch 5: {'accuracy': 0.7938144207000732, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
Epoch 7/100


100%|██████████| 60/60 [00:45<00:00,  1.31it/s]


Epoch [7/100], Loss: 0.3512
Accuracy at epoch 6: {'accuracy': 0.8041236996650696, 'precision': 0.5384615659713745, 'recall': 0.3499999940395355, 'f1_score': 0.42424243688583374}
Epoch 8/100


100%|██████████| 60/60 [00:49<00:00,  1.22it/s]


Epoch [8/100], Loss: 0.2366
Accuracy at epoch 7: {'accuracy': 0.8247422575950623, 'precision': 0.5517241358757019, 'recall': 0.800000011920929, 'f1_score': 0.6530612111091614}
Epoch 9/100


100%|██████████| 60/60 [00:46<00:00,  1.29it/s]


Epoch [9/100], Loss: 0.3022
Accuracy at epoch 8: {'accuracy': 0.7525773048400879, 'precision': 0.4444444477558136, 'recall': 0.800000011920929, 'f1_score': 0.5714285969734192}
Epoch 10/100


100%|██████████| 60/60 [00:47<00:00,  1.27it/s]


Epoch [10/100], Loss: 0.3525
Accuracy at epoch 9: {'accuracy': 0.7319587469100952, 'precision': 0.4285714328289032, 'recall': 0.8999999761581421, 'f1_score': 0.5806451439857483}
Epoch 11/100


100%|██████████| 60/60 [00:48<00:00,  1.24it/s]


Epoch [11/100], Loss: 0.2873
Accuracy at epoch 10: {'accuracy': 0.7835051417350769, 'precision': 0.4864864945411682, 'recall': 0.8999999761581421, 'f1_score': 0.6315789222717285}
Epoch 12/100


100%|██████████| 60/60 [00:47<00:00,  1.27it/s]


Epoch [12/100], Loss: 0.2601
Accuracy at epoch 11: {'accuracy': 0.8041236996650696, 'precision': 0.52173912525177, 'recall': 0.6000000238418579, 'f1_score': 0.5581395626068115}
Epoch 13/100


100%|██████████| 60/60 [00:44<00:00,  1.34it/s]


Epoch [13/100], Loss: 0.3172
Accuracy at epoch 12: {'accuracy': 0.8453608155250549, 'precision': 0.6190476417541504, 'recall': 0.6499999761581421, 'f1_score': 0.6341463327407837}
Epoch 14/100


100%|██████████| 60/60 [00:43<00:00,  1.37it/s]


Epoch [14/100], Loss: 0.2422
Accuracy at epoch 13: {'accuracy': 0.7835051417350769, 'precision': 0.4864864945411682, 'recall': 0.8999999761581421, 'f1_score': 0.6315789222717285}
Epoch 15/100


100%|██████████| 60/60 [00:43<00:00,  1.37it/s]


Epoch [15/100], Loss: 0.6211
Accuracy at epoch 14: {'accuracy': 0.8453608155250549, 'precision': 0.6315789222717285, 'recall': 0.6000000238418579, 'f1_score': 0.6153846383094788}
Epoch 16/100


100%|██████████| 60/60 [00:43<00:00,  1.38it/s]


Epoch [16/100], Loss: 0.5774
Accuracy at epoch 15: {'accuracy': 0.8556700944900513, 'precision': 0.6363636255264282, 'recall': 0.699999988079071, 'f1_score': 0.6666666865348816}
Epoch 17/100


100%|██████████| 60/60 [00:45<00:00,  1.32it/s]


Epoch [17/100], Loss: 1.5249
Accuracy at epoch 16: {'accuracy': 0.8556700944900513, 'precision': 0.6071428656578064, 'recall': 0.8500000238418579, 'f1_score': 0.7083333134651184}
Epoch 18/100


100%|██████████| 60/60 [00:46<00:00,  1.30it/s]


Epoch [18/100], Loss: 0.4481
Accuracy at epoch 17: {'accuracy': 0.7525773048400879, 'precision': 0.44999998807907104, 'recall': 0.8999999761581421, 'f1_score': 0.6000000238418579}
Epoch 19/100


100%|██████████| 60/60 [00:43<00:00,  1.39it/s]


Epoch [19/100], Loss: 0.1984
Accuracy at epoch 18: {'accuracy': 0.8041236996650696, 'precision': 0.5161290168762207, 'recall': 0.800000011920929, 'f1_score': 0.6274510025978088}
Epoch 20/100


100%|██████████| 60/60 [00:43<00:00,  1.38it/s]


Epoch [20/100], Loss: 0.2851
Accuracy at epoch 19: {'accuracy': 0.7835051417350769, 'precision': 0.4838709533214569, 'recall': 0.75, 'f1_score': 0.5882353186607361}
Epoch 21/100


100%|██████████| 60/60 [00:45<00:00,  1.33it/s]


Epoch [21/100], Loss: 1.6654
Accuracy at epoch 20: {'accuracy': 0.6804123520851135, 'precision': 0.3777777850627899, 'recall': 0.8500000238418579, 'f1_score': 0.5230769515037537}
Epoch 22/100


100%|██████████| 60/60 [00:43<00:00,  1.38it/s]


Epoch [22/100], Loss: 0.2151
Accuracy at epoch 21: {'accuracy': 0.8247422575950623, 'precision': 0.5454545617103577, 'recall': 0.8999999761581421, 'f1_score': 0.6792452931404114}
Epoch 23/100


100%|██████████| 60/60 [00:42<00:00,  1.40it/s]


Epoch [23/100], Loss: 0.2506
Accuracy at epoch 22: {'accuracy': 0.8247422575950623, 'precision': 0.5483871102333069, 'recall': 0.8500000238418579, 'f1_score': 0.6666666865348816}
Epoch 24/100


100%|██████████| 60/60 [00:42<00:00,  1.42it/s]


Epoch [24/100], Loss: 0.4728
Accuracy at epoch 23: {'accuracy': 0.8453608155250549, 'precision': 0.692307710647583, 'recall': 0.44999998807907104, 'f1_score': 0.5454545617103577}
Epoch 25/100


100%|██████████| 60/60 [00:42<00:00,  1.41it/s]


Epoch [25/100], Loss: 1.3689
Accuracy at epoch 24: {'accuracy': 0.8041236996650696, 'precision': 1.0, 'recall': 0.05000000074505806, 'f1_score': 0.095238097012043}
Epoch 26/100


100%|██████████| 60/60 [00:42<00:00,  1.42it/s]


Epoch [26/100], Loss: 0.3466
Accuracy at epoch 25: {'accuracy': 0.8659793734550476, 'precision': 0.7333333492279053, 'recall': 0.550000011920929, 'f1_score': 0.6285714507102966}
Epoch 27/100


100%|██████████| 60/60 [00:42<00:00,  1.43it/s]


Epoch [27/100], Loss: 1.3658
Accuracy at epoch 26: {'accuracy': 0.6082473993301392, 'precision': 0.3333333432674408, 'recall': 0.8999999761581421, 'f1_score': 0.4864864945411682}
Epoch 28/100


100%|██████████| 60/60 [00:42<00:00,  1.43it/s]


Epoch [28/100], Loss: 1.4587
Accuracy at epoch 27: {'accuracy': 0.8350515365600586, 'precision': 0.8333333134651184, 'recall': 0.25, 'f1_score': 0.38461539149284363}
Epoch 29/100


100%|██████████| 60/60 [00:41<00:00,  1.44it/s]


Epoch [29/100], Loss: 0.3674
Accuracy at epoch 28: {'accuracy': 0.8556700944900513, 'precision': 0.7142857313156128, 'recall': 0.5, 'f1_score': 0.5882353186607361}
Epoch 30/100


100%|██████████| 60/60 [00:41<00:00,  1.45it/s]


Epoch [30/100], Loss: 0.9154
Accuracy at epoch 29: {'accuracy': 0.8247422575950623, 'precision': 0.5882353186607361, 'recall': 0.5, 'f1_score': 0.5405405163764954}
Epoch 31/100


100%|██████████| 60/60 [00:41<00:00,  1.44it/s]


Epoch [31/100], Loss: 1.0470
Accuracy at epoch 30: {'accuracy': 0.7835051417350769, 'precision': 0.47826087474823, 'recall': 0.550000011920929, 'f1_score': 0.5116279125213623}
Epoch 32/100


100%|██████████| 60/60 [00:41<00:00,  1.44it/s]


Epoch [32/100], Loss: 0.3210
Accuracy at epoch 31: {'accuracy': 0.8247422575950623, 'precision': 0.5789473652839661, 'recall': 0.550000011920929, 'f1_score': 0.5641025900840759}
Epoch 33/100


100%|██████████| 60/60 [00:41<00:00,  1.44it/s]


Epoch [33/100], Loss: 1.6607
Accuracy at epoch 32: {'accuracy': 0.8144329786300659, 'precision': 0.5714285969734192, 'recall': 0.4000000059604645, 'f1_score': 0.47058823704719543}
Epoch 34/100


100%|██████████| 60/60 [00:42<00:00,  1.41it/s]


Epoch [34/100], Loss: 0.4748
Accuracy at epoch 33: {'accuracy': 0.7938144207000732, 'precision': 0.5, 'recall': 0.6000000238418579, 'f1_score': 0.5454545617103577}
Epoch 35/100


100%|██████████| 60/60 [00:41<00:00,  1.45it/s]


Epoch [35/100], Loss: 0.3933
Accuracy at epoch 34: {'accuracy': 0.8144329786300659, 'precision': 0.5555555820465088, 'recall': 0.5, 'f1_score': 0.5263158082962036}
Epoch 36/100


100%|██████████| 60/60 [00:41<00:00,  1.44it/s]


Epoch [36/100], Loss: 0.2487
Accuracy at epoch 35: {'accuracy': 0.8247422575950623, 'precision': 0.5789473652839661, 'recall': 0.550000011920929, 'f1_score': 0.5641025900840759}
Epoch 37/100


100%|██████████| 60/60 [00:41<00:00,  1.45it/s]


Epoch [37/100], Loss: 0.2993
Accuracy at epoch 36: {'accuracy': 0.8041236996650696, 'precision': 0.52173912525177, 'recall': 0.6000000238418579, 'f1_score': 0.5581395626068115}
Epoch 38/100


100%|██████████| 60/60 [00:41<00:00,  1.46it/s]


Epoch [38/100], Loss: 0.1277
Accuracy at epoch 37: {'accuracy': 0.6907216310501099, 'precision': 0.3863636255264282, 'recall': 0.8500000238418579, 'f1_score': 0.53125}
Epoch 39/100


100%|██████████| 60/60 [00:41<00:00,  1.44it/s]


Epoch [39/100], Loss: 1.4723
Accuracy at epoch 38: {'accuracy': 0.7938144207000732, 'precision': 0.5, 'recall': 0.6499999761581421, 'f1_score': 0.5652173757553101}
Epoch 40/100


100%|██████████| 60/60 [00:41<00:00,  1.46it/s]


Epoch [40/100], Loss: 0.4461
Accuracy at epoch 39: {'accuracy': 0.8453608155250549, 'precision': 0.6315789222717285, 'recall': 0.6000000238418579, 'f1_score': 0.6153846383094788}
Epoch 41/100


100%|██████████| 60/60 [00:41<00:00,  1.45it/s]


Epoch [41/100], Loss: 1.1612
Accuracy at epoch 40: {'accuracy': 0.7938144207000732, 'precision': 0.5, 'recall': 0.75, 'f1_score': 0.6000000238418579}
Epoch 42/100


100%|██████████| 60/60 [00:41<00:00,  1.45it/s]


Epoch [42/100], Loss: 0.8844
Accuracy at epoch 41: {'accuracy': 0.7319587469100952, 'precision': 0.38461539149284363, 'recall': 0.5, 'f1_score': 0.43478259444236755}
Epoch 43/100


100%|██████████| 60/60 [00:41<00:00,  1.45it/s]


Epoch [43/100], Loss: 0.4063
Accuracy at epoch 42: {'accuracy': 0.8453608155250549, 'precision': 0.6666666865348816, 'recall': 0.5, 'f1_score': 0.5714285969734192}
Epoch 44/100


100%|██████████| 60/60 [00:41<00:00,  1.45it/s]


Epoch [44/100], Loss: 0.1340
Accuracy at epoch 43: {'accuracy': 0.8350515365600586, 'precision': 0.625, 'recall': 0.5, 'f1_score': 0.5555555820465088}
Epoch 45/100


100%|██████████| 60/60 [00:41<00:00,  1.45it/s]


Epoch [45/100], Loss: 0.9983
Accuracy at epoch 44: {'accuracy': 0.8453608155250549, 'precision': 0.6470588445663452, 'recall': 0.550000011920929, 'f1_score': 0.5945945978164673}
Epoch 46/100


100%|██████████| 60/60 [00:41<00:00,  1.46it/s]


Epoch [46/100], Loss: 0.2831
Accuracy at epoch 45: {'accuracy': 0.8041236996650696, 'precision': 0.52173912525177, 'recall': 0.6000000238418579, 'f1_score': 0.5581395626068115}
Epoch 47/100


100%|██████████| 60/60 [00:39<00:00,  1.52it/s]


Epoch [47/100], Loss: 0.9437
Accuracy at epoch 46: {'accuracy': 0.8247422575950623, 'precision': 0.5714285969734192, 'recall': 0.6000000238418579, 'f1_score': 0.5853658318519592}
Epoch 48/100


100%|██████████| 60/60 [00:39<00:00,  1.50it/s]


Epoch [48/100], Loss: 0.1885
Accuracy at epoch 47: {'accuracy': 0.8041236996650696, 'precision': 0.5185185074806213, 'recall': 0.699999988079071, 'f1_score': 0.5957446694374084}
Epoch 49/100


100%|██████████| 60/60 [00:39<00:00,  1.52it/s]


Epoch [49/100], Loss: 0.3126
Accuracy at epoch 48: {'accuracy': 0.8041236996650696, 'precision': 0.5161290168762207, 'recall': 0.800000011920929, 'f1_score': 0.6274510025978088}
Epoch 50/100


100%|██████████| 60/60 [00:39<00:00,  1.52it/s]


Epoch [50/100], Loss: 0.8947
Accuracy at epoch 49: {'accuracy': 0.8556700944900513, 'precision': 0.6875, 'recall': 0.550000011920929, 'f1_score': 0.6111111044883728}
Epoch 51/100


100%|██████████| 60/60 [00:42<00:00,  1.42it/s]


Epoch [51/100], Loss: 0.3324
Accuracy at epoch 50: {'accuracy': 0.7835051417350769, 'precision': 0.47826087474823, 'recall': 0.550000011920929, 'f1_score': 0.5116279125213623}
Epoch 52/100


100%|██████████| 60/60 [00:42<00:00,  1.43it/s]


Epoch [52/100], Loss: 0.0354
Accuracy at epoch 51: {'accuracy': 0.7525773048400879, 'precision': 0.44999998807907104, 'recall': 0.8999999761581421, 'f1_score': 0.6000000238418579}
Epoch 53/100


100%|██████████| 60/60 [00:40<00:00,  1.47it/s]


Epoch [53/100], Loss: 1.7011
Accuracy at epoch 52: {'accuracy': 0.8453608155250549, 'precision': 0.692307710647583, 'recall': 0.44999998807907104, 'f1_score': 0.5454545617103577}
Epoch 54/100


100%|██████████| 60/60 [00:41<00:00,  1.46it/s]


Epoch [54/100], Loss: 1.4700
Accuracy at epoch 53: {'accuracy': 0.8453608155250549, 'precision': 0.692307710647583, 'recall': 0.44999998807907104, 'f1_score': 0.5454545617103577}
Epoch 55/100


100%|██████████| 60/60 [00:40<00:00,  1.49it/s]


Epoch [55/100], Loss: 0.1105
Accuracy at epoch 54: {'accuracy': 0.8247422575950623, 'precision': 0.5789473652839661, 'recall': 0.550000011920929, 'f1_score': 0.5641025900840759}
Epoch 56/100


100%|██████████| 60/60 [00:41<00:00,  1.46it/s]


Epoch [56/100], Loss: 1.1755
Accuracy at epoch 55: {'accuracy': 0.3814432919025421, 'precision': 0.25, 'recall': 1.0, 'f1_score': 0.4000000059604645}
Epoch 57/100


100%|██████████| 60/60 [00:41<00:00,  1.46it/s]


Epoch [57/100], Loss: 0.2135
Accuracy at epoch 56: {'accuracy': 0.8350515365600586, 'precision': 0.6666666865348816, 'recall': 0.4000000059604645, 'f1_score': 0.5}
Epoch 58/100


100%|██████████| 60/60 [00:40<00:00,  1.47it/s]


Epoch [58/100], Loss: 0.0767
Accuracy at epoch 57: {'accuracy': 0.8556700944900513, 'precision': 0.6875, 'recall': 0.550000011920929, 'f1_score': 0.6111111044883728}
Epoch 59/100


100%|██████████| 60/60 [00:40<00:00,  1.47it/s]


Epoch [59/100], Loss: 0.6095
Accuracy at epoch 58: {'accuracy': 0.8453608155250549, 'precision': 0.6470588445663452, 'recall': 0.550000011920929, 'f1_score': 0.5945945978164673}
Epoch 60/100


100%|██████████| 60/60 [00:41<00:00,  1.46it/s]


Epoch [60/100], Loss: 0.1875
Accuracy at epoch 59: {'accuracy': 0.8350515365600586, 'precision': 0.6428571343421936, 'recall': 0.44999998807907104, 'f1_score': 0.529411792755127}
Epoch 61/100


100%|██████████| 60/60 [00:41<00:00,  1.46it/s]


Epoch [61/100], Loss: 0.4723
Accuracy at epoch 60: {'accuracy': 0.5773195624351501, 'precision': 0.32203391194343567, 'recall': 0.949999988079071, 'f1_score': 0.4810126721858978}
Epoch 62/100


100%|██████████| 60/60 [00:40<00:00,  1.49it/s]


Epoch [62/100], Loss: 0.2292
Accuracy at epoch 61: {'accuracy': 0.8350515365600586, 'precision': 0.625, 'recall': 0.5, 'f1_score': 0.5555555820465088}
Epoch 63/100


100%|██████████| 60/60 [00:39<00:00,  1.54it/s]


Epoch [63/100], Loss: 0.4672
Accuracy at epoch 62: {'accuracy': 0.8144329786300659, 'precision': 0.5357142686843872, 'recall': 0.75, 'f1_score': 0.625}
Epoch 64/100


100%|██████████| 60/60 [00:38<00:00,  1.54it/s]


Epoch [64/100], Loss: 0.0379
Accuracy at epoch 63: {'accuracy': 0.8453608155250549, 'precision': 0.6666666865348816, 'recall': 0.5, 'f1_score': 0.5714285969734192}
Epoch 65/100


100%|██████████| 60/60 [00:39<00:00,  1.54it/s]


Epoch [65/100], Loss: 0.9037
Accuracy at epoch 64: {'accuracy': 0.8453608155250549, 'precision': 0.6470588445663452, 'recall': 0.550000011920929, 'f1_score': 0.5945945978164673}
Epoch 66/100


  7%|▋         | 4/60 [00:01<00:25,  2.19it/s]


KeyboardInterrupt: 

In [None]:
# ##########################################################################################################################################################################
# # ONLY FOR A VERY GOOD SETUP (not for our poor gpu :( )) 
# # DIVIDE A SEQUENCE INTO GROUPS OF FRMAES, FOR EACH GROUP MAKE A PREDICTION AND FINALLY MAKE A VOTING AMONG THE PREDICTIONS FOR GETTING THE PREVISION OF THE SEQUENCE


# num_epochs = 100

# for epoch in range(num_epochs):
#     print(f'Epoch {epoch+1}/{num_epochs}')
#     model.train()

#     for tensor_images, bboxes, masks, poses, intents in tqdm(train_loader):
#         poses = torch.stack(poses, dim=0)  # Now the shape is (batch_size, 36)
#         poses = poses.squeeze(0)
#         poses = poses.view(len(tensor_images), -1, 36).permute(1,0,2)  # Reshape to (batch_size, numeroFrames, 36)

#         # Move tensors to device
#         tensor_images = torch.stack(tensor_images, dim=1).squeeze(0).permute(0, 1, 2,3).to(device) # Converte la lista di immagini in un tensor
#         masks = torch.stack(masks,dim=1).squeeze(0).float().to(device)  # Converte la lista di maschere in un tensor
#         bboxes = bboxes.to(device)
#         poses = poses.to(device)
#         intents = intents.squeeze(0)[0].to(device)



#         optimizer.zero_grad()

#         outputs_list = []
#         num_frames = 10
#         sub_tensors = tensor_images.unfold(0, num_frames, num_frames)
#         sub_tensors = sub_tensors.contiguous().view(-1, num_frames, 3, 224, 224)
#         tensor_list = [sub_tensor for sub_tensor in sub_tensors]
#         for tensor in tensor_list:
#             output = model(tensor, bboxes, masks, poses)
#             outputs_list.append(output.cpu())


#             del tensor,output
#             torch.cuda.empty_cache()
#         print("hello")
#         outputs = sum(outputs_list)/len(outputs_list)
#         outputs_list.clear()

#         loss = criterion(outputs.to(device), intents.float())
#         loss.backward()

#         optimizer.step()
#         del tensor_images, masks, bboxes, poses, intents, outputs
#         torch.cuda.empty_cache()

#     print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
#     with torch.no_grad():
#         print(f'Accuracy at epoch {epoch}: {evaluate_metrics(model, test_loader, device)}')
        

