# CONFIGURATION:

In [1]:
from jaad_data import JAAD
import torch
from PIL import Image
from torchvision import transforms
from torchvision import models
import matplotlib.pyplot as plt
import network
import openpose
from openpose import model
from openpose import util
from openpose.body import Body
import copy
from tqdm import tqdm
import pickle
from torch.utils.data import Dataset, DataLoader
import numpy as np
import cv2
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
import torchmetrics
from torcheval.metrics import BinaryAccuracy
from torchmetrics.classification import BinaryAccuracy, BinaryPrecision, BinaryRecall, BinaryF1Score



In [2]:
JAAD_PATH = '../JAAD'
DEEPLAB_PATH = '../best_deeplabv3plus_resnet101_cityscapes_os16.pth'
SUBSET_PATH = '../subset2'

RESULTS_MASK_SUB = '../masks_results_sub.pkl'
RESULTS_MASK_BIG = '../masks_results_big.pkl'
RESULTS_MASK_BIG_TEST = '../masks_results_big_test.pkl'
RESULTS_MASK_SUB_TEST = '../masks_results_sub_test.pkl'

RESULTS_POSE_BIG = '../pose_results_big.pkl'
RESULTS_POSE_SUB = '../pose_results_sub.pkl'
RESULTS_POSE_BIG_TEST = '../pose_results_big_test.pkl'
RESULTS_POSE_SUB_TEST = '../pose_results_sub_test.pkl'

POSE_PATH = '../body_pose_model.pth'   


In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
torch.cuda.empty_cache()


cuda


In [4]:
# CONFIG (change these params for changing the experiment): 
# BiG = True for big dataset, False for small dataset
RUN = True
BIG =False 

In [5]:
if BIG:
    MASK_CMD = RESULTS_MASK_BIG
    POSE_CMD = RESULTS_POSE_BIG
    MASK_CMD_TEST = RESULTS_MASK_BIG_TEST
    POSE_CMD_TEST = RESULTS_POSE_BIG_TEST
    DT_CMD = JAAD_PATH
else:
    MASK_CMD = RESULTS_MASK_SUB
    POSE_CMD = RESULTS_POSE_SUB
    MASK_CMD_TEST = RESULTS_MASK_SUB_TEST
    POSE_CMD_TEST = RESULTS_POSE_SUB_TEST
    DT_CMD = SUBSET_PATH



# DATASET

In [7]:
# Load the JAAD dataset
jaad_dt = JAAD(data_path='../subset2')

data_opts = {
    'fstride': 15,
    'sample_type': 'all'
}

seq_train = jaad_dt.generate_data_trajectory_sequence('train', **data_opts)  
seq_test = jaad_dt.generate_data_trajectory_sequence('test', **data_opts)  

---------------------------------------------------------
Generating action sequence data
fstride: 15
sample_type: all
subset: default
height_rng: [0, inf]
squarify_ratio: 0
data_split_type: default
seq_type: intention
min_track_size: 15
random_params: {'ratios': None, 'val_data': True, 'regen_data': False}
kfold_params: {'num_folds': 5, 'fold': 1}
---------------------------------------------------------
Generating database for jaad
jaad database loaded from c:\Users\jacop\Documents\ComputerVision\subset2\data_cache\jaad_database.pkl
---------------------------------------------------------
Generating intention data
Split: train
Number of pedestrians: 185 
Total number of samples: 60 
---------------------------------------------------------
Generating action sequence data
fstride: 15
sample_type: all
subset: default
height_rng: [0, inf]
squarify_ratio: 0
data_split_type: default
seq_type: intention
min_track_size: 15
random_params: {'ratios': None, 'val_data': True, 'regen_data': Fal

# GLOBAL CONTEXT EXTRACTION:

In [10]:
if RUN:
    deeplab_model = network.modeling.__dict__['deeplabv3plus_resnet101'](num_classes=19)
    deeplab_model.load_state_dict(torch.load(DEEPLAB_PATH)['model_state'])
    deeplab_model.to(device)
    deeplab_model.eval()

In [11]:
#trasformazioni che vengono usate dentro global context (modifica l'input prima che vada in deeplab)

train_transforms = transforms.Compose([
    transforms.Resize((512, 512)),  # Ridimensiona le immagini a 512x512
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

#processa output deeplab
GC_trans = transforms.Compose([
    transforms.Resize((224, 224)),  # Ridimensiona le immagini a 224x224
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [None]:
# Visualizza la maschera semantica
def visualize_mask(image_path, mask):
    image = Image.open(image_path).convert("RGB")
    #image = image.resize((512, 512))  # Ridimensiona per la visualizzazione
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.imshow(image)
    plt.title("Original Image")
    plt.subplot(1, 2, 2)
    plt.imshow(mask, cmap='jet')
    plt.title("Semantic Mask")
    plt.show()

In [12]:
def get_segmentation_mask(image_path, model, preprocess):
    """ funzione che prende in input le path delle imagini, il modello e la funzione di preprocessamento 
    e restituisce la maschera segmentata dell'immagine resizata a 224x224"""
    
    # Load the image
    input_image = Image.open(image_path).convert("RGB")
    input_tensor = preprocess(input_image).to(device)
    input_batch = input_tensor.unsqueeze(0)  
    
    # Pass the image through the model
    with torch.no_grad():
        output = model(input_batch)
        
    # Check if output is a tensor or a dictionary (auxiliary control)
    if isinstance(output, dict):
        output = output['out'][0]
    elif isinstance(output, torch.Tensor):
        output = output[0]
    else:
        raise ValueError(f"Unexpected output type: {type(output)}")
    
    # Convert the output to a mask
    output_predictions = output.argmax(0)

####################################################################################
    #OPTIONAL FOR DEBUG AND PLOTTING:
    # output_predictions_pic = output_predictions.clone().cpu()
    # visualize_mask(image_path,output_predictions_pic) 
####################################################################################

    # Fix dimensions for the mask and convert to float
    output_predictions = output_predictions.unsqueeze(0).unsqueeze(0).float()

    #convert to RGB and resize to (224,224), because VGG want(3,224,224) as input
    tr = transforms.ToPILImage()
    pic = tr(output_predictions.squeeze(1))
    pic= pic.convert("RGB")
    resized_mask = GC_trans(pic)

    return resized_mask

In [13]:
def process_video_frames(seq_train, model, preprocess):
    """funzione che prende in input la sequenza di training, il modello e la funzione di preprocessamento, restituisce una
    lista di segmentation mask per ogni frame di ogni video della sequenza di training"""
    
    all_masks = []
    for video_frames in tqdm(seq_train['image'], desc="Processing videos"):
        video_masks = []

        for frame_path in tqdm(video_frames, desc="Processing frames", leave=False):
            mask = get_segmentation_mask(frame_path, model, preprocess)
            video_masks.append(mask)
        all_masks.append(video_masks)
    
    return all_masks


In [15]:
if RUN:
    all_video_masks = process_video_frames(seq_train, deeplab_model, train_transforms)
    seq_train['masks'] = all_video_masks
    all_video_masks_test = process_video_frames(seq_test, deeplab_model, train_transforms)
    seq_test['masks'] = all_video_masks_test

    #cleaning for my poor gpu
    del deeplab_model
    torch.cuda.empty_cache()

    # save data in the .pkl files 
    with open(MASK_CMD, 'wb') as f:
        pickle.dump(seq_train['masks'], f)
    with open(MASK_CMD_TEST, 'wb') as f:
        pickle.dump(seq_test['masks'], f)
else:
    #recover data:
    with open(MASK_CMD, 'rb') as f:
        seq_train['masks'] = pickle.load(f)
    with open(MASK_CMD_TEST, 'rb') as f:
        seq_test['masks'] = pickle.load(f)


# LOCAL CONTEXT:

In [None]:
# Trasformazioni per le immaginin che vengono per il local context 
transform_lc = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [16]:
def crop_image_cv2(img, bbox):
    """ funzione che croppa i frames sul bounding boxes, le imagini sono in formato cv2"""
    
    x1, y1, x2, y2 = bbox
    return img[int(y1):int(y2), int(x1):int(x2)]

In [17]:
"""Trasformation for the local context's images, enhance the quality of the images by appling gaussian filter, unsharp mask e bilateral filter"""
all_images = [] #list of the images
for i in tqdm(range(len(seq_train['image'])), desc="Processing videos"):
    aux_list = []
    for j in tqdm(range(len(seq_train['image'][i])), desc="Processing frames", leave=False):
        
        # Open the images from the paths
        img_path = seq_train['image'][i][j]
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # Convert to RGB
        
        #get the bbox and crop arround it
        bbox = seq_train['bbox'][i][j]
        cropped_images = crop_image_cv2(img, bbox)
        
        # Resize to 224x224 (VGG input size)
        final_image = cv2.resize(cropped_images, (224, 224))
        aux_list.append(final_image)
    all_images.append(aux_list)

Processing videos: 100%|██████████| 60/60 [00:38<00:00,  1.57it/s]


In [19]:
#same of above, but for the test set
all_images_test = []
for i in tqdm(range(len(seq_test['image'])), desc="Processing videos"):
    aux_list = []
    for j in tqdm(range(len(seq_test['image'][i])), desc="Processing frames", leave=False):
        
        # Open the images from the paths
        img_path = seq_test['image'][i][j]
        img = cv2.imread(img_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # Convert to RGB

        #get the bbox and crop arround it
        bbox = seq_test['bbox'][i][j]
        cropped_images = crop_image_cv2(img, bbox)
        
        # Resize to 224x224 (VGG input size)
        final_image = cv2.resize(cropped_images, (224, 224))
        aux_list.append(final_image)
    all_images_test.append(aux_list)


Processing videos: 100%|██████████| 97/97 [01:02<00:00,  1.56it/s]


# POSE KEYPOINTS: 

In [21]:
def extract_pose_sequence(frames, body_model):
    """funzione che prende in input i frames e il modello di openpose e restituisce una lista di tensori di pose 
    per ciascuna persona nel tempo, se non prende nessuna posa mette un placeholder"""
    #print(frames)
    pose_sequences = []  # Lista di pose per ciascuna persona nel tempo
    pose_placeholder = torch.zeros((36,), dtype=torch.float32)
    for frame in frames:
        candidate, subset = body_model(frame)

        ###############################################################
        # DEBUG & PLOTTING: Uncomment the following
        canvas = copy.deepcopy(frame)
        canvas = util.draw_bodypose(canvas, candidate, subset)
        plt.imshow(canvas[:, :, [2, 1, 0]])  
        plt.axis('off')
        plt.show()
        ###############################################################

        frame_poses = []
        for person in subset:
            if person[-1] >= 4:  # Almeno 4 punti chiave rilevati
                pose = []
                for i in range(18):
                    if person[i] != -1:
                        x, y = candidate[int(person[i])][:2]
                    else:
                        x, y = -1, -1  # Punti chiave mancanti
                    pose.extend([x, y])
                frame_poses.append(pose)
        if not frame_poses:
            frame_poses = [pose_placeholder.tolist()] #use placeholder if no pose detected (for dimensional consistency)
        pose_sequences.append(frame_poses)

    # Trasponi la lista di liste per ottenere le sequenze temporali per ciascuna persona
    person_pose_sequences = list(map(list, zip(*pose_sequences)))
    person_pose_sequences = [torch.tensor(person_poses, dtype=torch.float32) for person_poses in person_pose_sequences]
    
    return person_pose_sequences

In [22]:

if RUN:
    body_model = Body(POSE_PATH)

    # Caricamento dei frame e estrazione delle pose
    all_poses = []
    all_poses_test = []

    #itera tra i video e prendi i frames
    for pics in tqdm(all_images, desc="Extracting poses from image sequences"):
        #print(len(pics))
        pose_sequences = extract_pose_sequence(pics, body_model)
        all_poses.append(pose_sequences)
        #print(pose_sequences.shape)
    # Aggiungi le pose estratte alla sequenza di allenamento
    seq_train['poses'] = all_poses

    for pics in tqdm(all_images_test, desc="Extracting poses from image sequences of test set"):
        pose_sequences = extract_pose_sequence(pics, body_model)
        all_poses_test.append(pose_sequences)
    # Aggiungi le pose estratte alla sequenza di allenamento
    seq_test['poses'] = all_poses_test
    del body_model
    torch.cuda.empty_cache()

    # save data in the .pkl files 
    with open(POSE_CMD, 'wb') as f:
        pickle.dump(seq_train['poses'], f)
    with open(POSE_CMD_TEST, 'wb') as f:
        pickle.dump(seq_test['poses'], f)
else:
    #recover data:
    with open(POSE_CMD, 'rb') as f:
        seq_train['poses'] = pickle.load(f)
    with open(POSE_CMD_TEST, 'rb') as f:
        seq_test['poses'] = pickle.load(f)


# MODEL:

In [24]:
class VisionBranchLocal(torch.nn.Module):
    """definizione del modello per il local context, prende in input le immagini croppate e restituisce un tensore,
    le immagini croppate vengono fatte passare dentro una VGG16, una GRU e un attention block"""

    def __init__(self, vgg16):
        super(VisionBranchLocal, self).__init__()
        self.vgg16 = vgg16
        self.avgpool = torch.nn.AvgPool2d(kernel_size=14)  # Pooling layer con kernel 14x14
        self.gru = torch.nn.GRU(input_size=512, hidden_size=256, num_layers=2, batch_first=True)
        self.fc = torch.nn.Linear(256, 2)    # Fully connected layer
        self.attn = torch.nn.Linear(256, 1)  # Attention layer
        self.tanh = torch.nn.Sigmoid()
        #self.tanh = torch.nn.Tanh() #another try
    def forward(self, cropped_images):
        seq_len, c, h, w = cropped_images.size()
        
        # Estrai feature dalle immagini con VGG16
        vgg_features = []
        for i in range(seq_len):            

            img = cropped_images[i]            
            vgg_feat_img = self.vgg16.features(img)
            pooled_feat_img = self.avgpool(vgg_feat_img)  # Applica il pooling
            vgg_feat_img = pooled_feat_img.view(pooled_feat_img.size(0), -1)  # Flatten features
            vgg_features.append(vgg_feat_img)
        
        vgg_features = torch.stack(vgg_features, dim=1).permute(2,1,0)

        #applica la gru
        gru_out, _ = self.gru(vgg_features)

        #applica l'attention
        attn_weights = torch.softmax(self.attn(gru_out), dim=1)
        context_vector = torch.sum(attn_weights * gru_out, dim=1)
        
 
        out = self.tanh((context_vector))
        return out


In [25]:
class VisionBranchGlobal(torch.nn.Module):
    """definizione del modello per il global context, prende in input le maskere semantiche e restituisce un tensore,
    le maskere semantiche vengono fatte passare dentro una VGG16, una GRU e un attention block"""

    def __init__(self, vgg16):
        super(VisionBranchGlobal, self).__init__()
        self.vgg16 = vgg16
        self.avgpool = torch.nn.AvgPool2d(kernel_size=14)  # Pooling layer con kernel 14x14
        self.gru = torch.nn.GRU(input_size=512, hidden_size=256, num_layers=2, batch_first=True)
        self.fc = torch.nn.Linear(256, 2)    # Fully connected layer
        self.attn = torch.nn.Linear(256, 1)  # Attention layer
        self.tanh = torch.nn.Tanh()

    def forward(self, masks):
        seq_len = masks.size()[0]

        # Estrai feature dalle immagini con VGG16
        vgg_features = []
        for i in range(seq_len):            
            img = masks[i]            
            vgg_feat_img = self.vgg16.features(img)
            pooled_feat_img = self.avgpool(vgg_feat_img)  # Applica il pooling
            vgg_feat_img = pooled_feat_img.view(pooled_feat_img.size(0), -1)  # Flatten features
            vgg_features.append(vgg_feat_img)
        
        vgg_features = torch.stack(vgg_features, dim=1).permute(2,1,0)

        gru_out, _ = self.gru(vgg_features)
        attn_scores = self.attn(gru_out)  
        
        attn_weights = torch.softmax(attn_scores, dim=1)  
        
        context_vector = torch.sum(attn_weights * gru_out, dim=1)  
        
        out = self.tanh((context_vector))
        return out


In [26]:
class NVisionBranch(torch.nn.Module):
    """classe relativa al non-vision brach, prende in input le pose e le bbox in formato tensore, esse vengono fatte passare
    dentro una GRU e un attention block, l'ordine influenza la prestazioni"""
    
    def __init__(self):
        super(NVisionBranch, self).__init__()
        self.gru = torch.nn.GRU(input_size=36, hidden_size=256, num_layers=2, batch_first=True)
        self.gru2 = torch.nn.GRU(input_size=256+4, hidden_size=256, num_layers=2, batch_first=True)
        self.attn = torch.nn.Linear(256, 1)  # Attention layer
        self.tanh = torch.nn.Tanh()

    def forward(self, poses,bbox):
        gru_out, _ = self.gru(poses)
        LP = torch.cat((gru_out,bbox),dim=-1)
        gru_out, _ = self.gru2(LP)

        # Attention mechanism
        attn_scores = self.attn(gru_out) 
        
        attn_weights = torch.softmax(attn_scores, dim=1)  
        
        context_vector = torch.sum(attn_weights * gru_out, dim=1)  
        
        out = self.tanh(context_vector)
        return out


In [28]:
class PedestrianIntentModel(torch.nn.Module):
    """definizione del modello finale, prende in input gli output del vision brach e del non vision branch, viene fatta
    una concatenazione che in seguito passa dentro un attention e un fully connected layer, l'output è la predizione,
    (passa non passa)"""

    def __init__(self, vision_branch_local,vision_branch_global,non_vision_branch):
        super(PedestrianIntentModel, self).__init__()
        self.vision_branch_local = vision_branch_local #output of the vision branch local
        self.vision_branch_global = vision_branch_global #output of the vision branch global
        self.non_vision_branch = non_vision_branch #output of the non vision branch
        self.attn = torch.nn.Linear(768, 768)  # Attention layer

        self.fc1 = torch.nn.Linear(768, 256) # fully connected layer
        self.fc2 = torch.nn.Linear(256,1) # Output: crossing or not crossing

        
    def forward(self, cropped_images, bboxes, masks, poses):
        vision_out_local = self.vision_branch_local(cropped_images)
        vision_out_global = self.vision_branch_global(masks)
        non_vision_out = self.non_vision_branch(poses, bboxes)

        vision_out = torch.cat((vision_out_local, vision_out_global), dim=-1)
        final_fusion = torch.cat((vision_out, non_vision_out), dim=-1)

        attn_scores = self.attn(final_fusion)
        attn_weights = torch.softmax(attn_scores, dim=-1)
        context_vector = torch.sum(attn_weights * final_fusion, dim=0)

        out = (self.fc2(self.fc1(context_vector))) #raw output

        return out


In [29]:
class VGG16_FeatureExtractor(torch.nn.Module):

    def __init__(self):
        super(VGG16_FeatureExtractor, self).__init__()
        self.features = torch.nn.Sequential(*list(vgg16.features.children())[:24]) # block4_pool è il 24° livello
    def forward(self, x):
        x = self.features(x)
        return x

In [30]:
# Carica il modello VGG19 pre-addestrato
vgg16 = models.vgg16(pretrained=True)

#cut the model at the 24th layer:
vgg16_fe = VGG16_FeatureExtractor()
vgg16_fe


# define the models of each branches
model_local = VisionBranchLocal(vgg16_fe).to(device)
model_global = VisionBranchGlobal(vgg16_fe).to(device)
model_non_vision = NVisionBranch().to(device)
model = PedestrianIntentModel(model_local,model_global,model_non_vision).to(device)




# DATASET & DATALOADER

In [31]:
class JAADDataset(Dataset):
    """definizione della classe per il custom dataset, prende in input la seq_train, le immagini e le trasformazioni,
    restituisce il tensore delle immagini croppate, le bboxes, le maschere, le pose e la lables"""

    def __init__(self, seq_data, all_images, transform=None):
        self.seq_data = seq_data
        self.all_images = all_images
        self.transform = transform

    def __len__(self):
        return len(self.seq_data['image'])

    def __getitem__(self, idx):
        bbox_sequence = self.seq_data['bbox'][idx]
        masks = self.seq_data['masks'][idx]
        poses = self.seq_data['poses'][idx]
        all_images = self.all_images[idx]

        if self.transform:
            tensor_images = [self.transform(img) for img in all_images]

        bboxes = torch.tensor(self.seq_data['bbox'][idx], dtype=torch.float32)
        intents = torch.tensor(self.seq_data['intent'][idx], dtype=torch.float32)
        return  tensor_images, bboxes, masks, poses, intents

In [32]:
train_dataset = JAADDataset(seq_train,all_images=all_images, transform=transform_lc)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_dataset = JAADDataset(seq_test,all_images=all_images_test, transform=transform_lc)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)

# TRAINING:

In [33]:
criterion = torch.nn.BCEWithLogitsLoss() #input are raw output
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001) 

In [37]:
def evaluate_metrics(net, loader, device):
    acc = BinaryAccuracy().to(device)
    precision = BinaryPrecision().to(device)
    recall = BinaryRecall().to(device)
    f1_score = BinaryF1Score().to(device)
    
    model.eval()

    for tensor_images, bboxes, masks, poses, intents in loader:
        poses = torch.stack(poses, dim=0)  
        poses = poses.squeeze(0)
        poses = poses.view(len(tensor_images), -1, 36).permute(1, 0, 2) 
        
        # Move tensors to device
        tensor_images = torch.stack(tensor_images, dim=1).squeeze(0).permute(0, 1, 2, 3).to(device)  # Convert image list to tensor
        masks = torch.stack(masks, dim=1).squeeze(0).float().to(device)  # Convert mask list to tensor
        bboxes = bboxes.to(device)
        poses = poses.to(device)
        intents = intents.squeeze(0)[0].to(device)
        
        ypred = net(tensor_images, bboxes, masks, poses)

        # Move tensors back to CPU and clean up memory
        tensor_images.cpu()
        masks.cpu()
        bboxes.cpu()
        poses.cpu()
        del tensor_images, bboxes, masks, poses
        torch.cuda.empty_cache()
        
        # Update metrics
        acc.update(ypred, intents)
        precision.update(ypred, intents)
        recall.update(ypred, intents)
        f1_score.update(ypred, intents)
        
        intents.cpu()
        del intents
        torch.cuda.empty_cache()
    
    metrics = {
        'accuracy': acc.compute().item(),
        'precision': precision.compute().item(),
        'recall': recall.compute().item(),
        'f1_score': f1_score.compute().item()
    }
    
    # Reset metrics for the next evaluation
    acc.reset()
    precision.reset()
    recall.reset()
    f1_score.reset()
    
    return metrics


In [38]:

num_epochs = 100

# free the memory
torch.cuda.empty_cache()

for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}')
    model.train()

    for tensor_images, bboxes, masks, poses, intents in tqdm(train_loader):

        poses = torch.stack(poses, dim=0)  # Now the shape is (batch_size, 36)
        poses = poses.squeeze(0)
        poses = poses.view(len(tensor_images), -1, 36).permute(1,0,2)  # Reshape to (batch_size, numeroFrames, 36)


        # Move tensors to device
        tensor_images = torch.stack(tensor_images, dim=1).squeeze(0).permute(0, 1, 2,3).to(device)  # Converte la lista di immagini in un tensor
        masks = torch.stack(masks,dim=1).squeeze(0).float().to(device)  # Converte la lista di maschere in un tensor
        bboxes = bboxes.to(device)
        poses = poses.to(device)
        intents = intents.squeeze(0)[0].to(device)


        optimizer.zero_grad()

        outputs = model(tensor_images, bboxes, masks, poses)
        tensor_images.cpu()
        bboxes.cpu()
        poses.cpu()
        loss = criterion(outputs, intents.float())
        loss.backward()


        optimizer.step()
        del tensor_images, masks, bboxes, poses, intents, outputs
        torch.cuda.empty_cache()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
    with torch.no_grad():
        print(f'Accuracy at epoch {epoch}: {evaluate_metrics(model, test_loader, device)}')
        



Epoch 1/100


100%|██████████| 60/60 [00:34<00:00,  1.72it/s]


Epoch [1/100], Loss: 0.6502
Accuracy at epoch 0: {'accuracy': 0.7938144207000732, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
Epoch 2/100


100%|██████████| 60/60 [00:33<00:00,  1.81it/s]


Epoch [2/100], Loss: 0.6180
Accuracy at epoch 1: {'accuracy': 0.7938144207000732, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
Epoch 3/100


100%|██████████| 60/60 [00:33<00:00,  1.80it/s]


Epoch [3/100], Loss: 0.8512
Accuracy at epoch 2: {'accuracy': 0.7938144207000732, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
Epoch 4/100


100%|██████████| 60/60 [00:33<00:00,  1.78it/s]


Epoch [4/100], Loss: 0.5915
Accuracy at epoch 3: {'accuracy': 0.7938144207000732, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
Epoch 5/100


100%|██████████| 60/60 [00:34<00:00,  1.74it/s]


Epoch [5/100], Loss: 0.8073
Accuracy at epoch 4: {'accuracy': 0.7938144207000732, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
Epoch 6/100


100%|██████████| 60/60 [00:34<00:00,  1.72it/s]


Epoch [6/100], Loss: 0.8180
Accuracy at epoch 5: {'accuracy': 0.7938144207000732, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
Epoch 7/100


100%|██████████| 60/60 [00:35<00:00,  1.69it/s]


Epoch [7/100], Loss: 0.5971
Accuracy at epoch 6: {'accuracy': 0.7938144207000732, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
Epoch 8/100


100%|██████████| 60/60 [00:36<00:00,  1.63it/s]


Epoch [8/100], Loss: 0.5669
Accuracy at epoch 7: {'accuracy': 0.7938144207000732, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
Epoch 9/100


100%|██████████| 60/60 [00:35<00:00,  1.67it/s]


Epoch [9/100], Loss: 0.8473
Accuracy at epoch 8: {'accuracy': 0.7938144207000732, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
Epoch 10/100


100%|██████████| 60/60 [00:33<00:00,  1.78it/s]


Epoch [10/100], Loss: 0.5608
Accuracy at epoch 9: {'accuracy': 0.7938144207000732, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
Epoch 11/100


100%|██████████| 60/60 [00:33<00:00,  1.77it/s]


Epoch [11/100], Loss: 0.5225
Accuracy at epoch 10: {'accuracy': 0.7938144207000732, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
Epoch 12/100


100%|██████████| 60/60 [00:33<00:00,  1.78it/s]


Epoch [12/100], Loss: 0.5606
Accuracy at epoch 11: {'accuracy': 0.7938144207000732, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
Epoch 13/100


100%|██████████| 60/60 [00:33<00:00,  1.77it/s]


Epoch [13/100], Loss: 0.4781
Accuracy at epoch 12: {'accuracy': 0.7938144207000732, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
Epoch 14/100


100%|██████████| 60/60 [00:34<00:00,  1.76it/s]


Epoch [14/100], Loss: 0.8471
Accuracy at epoch 13: {'accuracy': 0.7938144207000732, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
Epoch 15/100


100%|██████████| 60/60 [00:34<00:00,  1.73it/s]


Epoch [15/100], Loss: 0.9419
Accuracy at epoch 14: {'accuracy': 0.7938144207000732, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
Epoch 16/100


100%|██████████| 60/60 [00:35<00:00,  1.69it/s]


Epoch [16/100], Loss: 0.5081
Accuracy at epoch 15: {'accuracy': 0.7938144207000732, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
Epoch 17/100


100%|██████████| 60/60 [00:36<00:00,  1.63it/s]


Epoch [17/100], Loss: 0.8426
Accuracy at epoch 16: {'accuracy': 0.7938144207000732, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
Epoch 18/100


100%|██████████| 60/60 [00:34<00:00,  1.75it/s]


Epoch [18/100], Loss: 0.4990
Accuracy at epoch 17: {'accuracy': 0.7938144207000732, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
Epoch 19/100


100%|██████████| 60/60 [00:33<00:00,  1.78it/s]


Epoch [19/100], Loss: 0.6315
Accuracy at epoch 18: {'accuracy': 0.7938144207000732, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
Epoch 20/100


100%|██████████| 60/60 [00:33<00:00,  1.77it/s]


Epoch [20/100], Loss: 1.1430
Accuracy at epoch 19: {'accuracy': 0.7938144207000732, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
Epoch 21/100


100%|██████████| 60/60 [00:33<00:00,  1.78it/s]


Epoch [21/100], Loss: 0.5024
Accuracy at epoch 20: {'accuracy': 0.7938144207000732, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
Epoch 22/100


100%|██████████| 60/60 [00:33<00:00,  1.78it/s]


Epoch [22/100], Loss: 0.6764
Accuracy at epoch 21: {'accuracy': 0.8041236996650696, 'precision': 0.6666666865348816, 'recall': 0.10000000149011612, 'f1_score': 0.17391304671764374}
Epoch 23/100


100%|██████████| 60/60 [00:33<00:00,  1.77it/s]


Epoch [23/100], Loss: 0.6947
Accuracy at epoch 22: {'accuracy': 0.7938144207000732, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
Epoch 24/100


100%|██████████| 60/60 [00:33<00:00,  1.77it/s]


Epoch [24/100], Loss: 0.4846
Accuracy at epoch 23: {'accuracy': 0.7938144207000732, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
Epoch 25/100


100%|██████████| 60/60 [00:34<00:00,  1.76it/s]


Epoch [25/100], Loss: 1.0578
Accuracy at epoch 24: {'accuracy': 0.7938144207000732, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
Epoch 26/100


100%|██████████| 60/60 [00:33<00:00,  1.79it/s]


Epoch [26/100], Loss: 0.4866
Accuracy at epoch 25: {'accuracy': 0.7731958627700806, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
Epoch 27/100


100%|██████████| 60/60 [00:33<00:00,  1.82it/s]


Epoch [27/100], Loss: 0.6344
Accuracy at epoch 26: {'accuracy': 0.7835051417350769, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
Epoch 28/100


100%|██████████| 60/60 [00:32<00:00,  1.82it/s]


Epoch [28/100], Loss: 0.3366
Accuracy at epoch 27: {'accuracy': 0.7938144207000732, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
Epoch 29/100


100%|██████████| 60/60 [00:33<00:00,  1.82it/s]


Epoch [29/100], Loss: 1.2039
Accuracy at epoch 28: {'accuracy': 0.7628865838050842, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
Epoch 30/100


100%|██████████| 60/60 [00:33<00:00,  1.81it/s]


Epoch [30/100], Loss: 0.5027
Accuracy at epoch 29: {'accuracy': 0.7835051417350769, 'precision': 0.3333333432674408, 'recall': 0.05000000074505806, 'f1_score': 0.08695652335882187}
Epoch 31/100


100%|██████████| 60/60 [00:32<00:00,  1.83it/s]


Epoch [31/100], Loss: 0.3148
Accuracy at epoch 30: {'accuracy': 0.7731958627700806, 'precision': 0.25, 'recall': 0.05000000074505806, 'f1_score': 0.0833333358168602}
Epoch 32/100


100%|██████████| 60/60 [00:32<00:00,  1.84it/s]


Epoch [32/100], Loss: 0.6644
Accuracy at epoch 31: {'accuracy': 0.7835051417350769, 'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}
Epoch 33/100


100%|██████████| 60/60 [00:32<00:00,  1.84it/s]


Epoch [33/100], Loss: 0.9297
Accuracy at epoch 32: {'accuracy': 0.7938144207000732, 'precision': 0.5, 'recall': 0.550000011920929, 'f1_score': 0.523809552192688}
Epoch 34/100


100%|██████████| 60/60 [00:33<00:00,  1.80it/s]


Epoch [34/100], Loss: 0.3223
Accuracy at epoch 33: {'accuracy': 0.8144329786300659, 'precision': 0.625, 'recall': 0.25, 'f1_score': 0.3571428656578064}
Epoch 35/100


100%|██████████| 60/60 [00:33<00:00,  1.81it/s]


Epoch [35/100], Loss: 0.3265
Accuracy at epoch 34: {'accuracy': 0.7938144207000732, 'precision': 0.5, 'recall': 0.15000000596046448, 'f1_score': 0.23076923191547394}
Epoch 36/100


100%|██████████| 60/60 [00:32<00:00,  1.83it/s]


Epoch [36/100], Loss: 0.4487
Accuracy at epoch 35: {'accuracy': 0.7835051417350769, 'precision': 0.4444444477558136, 'recall': 0.20000000298023224, 'f1_score': 0.27586206793785095}
Epoch 37/100


100%|██████████| 60/60 [00:32<00:00,  1.84it/s]


Epoch [37/100], Loss: 0.1894
Accuracy at epoch 36: {'accuracy': 0.7628865838050842, 'precision': 0.3636363744735718, 'recall': 0.20000000298023224, 'f1_score': 0.25806450843811035}
Epoch 38/100


100%|██████████| 60/60 [00:33<00:00,  1.77it/s]


Epoch [38/100], Loss: 0.8527
Accuracy at epoch 37: {'accuracy': 0.7835051417350769, 'precision': 0.4000000059604645, 'recall': 0.10000000149011612, 'f1_score': 0.1599999964237213}
Epoch 39/100


100%|██████████| 60/60 [00:34<00:00,  1.76it/s]


Epoch [39/100], Loss: 0.3008
Accuracy at epoch 38: {'accuracy': 0.7628865838050842, 'precision': 0.3333333432674408, 'recall': 0.15000000596046448, 'f1_score': 0.2068965584039688}
Epoch 40/100


100%|██████████| 60/60 [00:32<00:00,  1.84it/s]


Epoch [40/100], Loss: 0.3022
Accuracy at epoch 39: {'accuracy': 0.7938144207000732, 'precision': 0.5, 'recall': 0.30000001192092896, 'f1_score': 0.375}
Epoch 41/100


100%|██████████| 60/60 [00:32<00:00,  1.84it/s]


Epoch [41/100], Loss: 0.6133
Accuracy at epoch 40: {'accuracy': 0.7628865838050842, 'precision': 0.4285714328289032, 'recall': 0.44999998807907104, 'f1_score': 0.4390243887901306}
Epoch 42/100


100%|██████████| 60/60 [00:32<00:00,  1.84it/s]


Epoch [42/100], Loss: 0.4352
Accuracy at epoch 41: {'accuracy': 0.7628865838050842, 'precision': 0.4444444477558136, 'recall': 0.6000000238418579, 'f1_score': 0.5106382966041565}
Epoch 43/100


100%|██████████| 60/60 [00:32<00:00,  1.84it/s]


Epoch [43/100], Loss: 1.0539
Accuracy at epoch 42: {'accuracy': 0.7938144207000732, 'precision': 0.5, 'recall': 0.05000000074505806, 'f1_score': 0.09090909361839294}
Epoch 44/100


100%|██████████| 60/60 [00:32<00:00,  1.84it/s]


Epoch [44/100], Loss: 0.2129
Accuracy at epoch 43: {'accuracy': 0.8041236996650696, 'precision': 0.5714285969734192, 'recall': 0.20000000298023224, 'f1_score': 0.29629629850387573}
Epoch 45/100


100%|██████████| 60/60 [00:33<00:00,  1.81it/s]


Epoch [45/100], Loss: 0.7359
Accuracy at epoch 44: {'accuracy': 0.7938144207000732, 'precision': 0.5, 'recall': 0.699999988079071, 'f1_score': 0.5833333134651184}
Epoch 46/100


100%|██████████| 60/60 [00:33<00:00,  1.81it/s]


Epoch [46/100], Loss: 0.2407
Accuracy at epoch 45: {'accuracy': 0.7731958627700806, 'precision': 0.3333333432674408, 'recall': 0.10000000149011612, 'f1_score': 0.1538461595773697}
Epoch 47/100


100%|██████████| 60/60 [00:33<00:00,  1.81it/s]


Epoch [47/100], Loss: 0.7114
Accuracy at epoch 46: {'accuracy': 0.8041236996650696, 'precision': 0.52173912525177, 'recall': 0.6000000238418579, 'f1_score': 0.5581395626068115}
Epoch 48/100


100%|██████████| 60/60 [00:33<00:00,  1.79it/s]


Epoch [48/100], Loss: 0.4545
Accuracy at epoch 47: {'accuracy': 0.8247422575950623, 'precision': 0.6666666865348816, 'recall': 0.30000001192092896, 'f1_score': 0.4137931168079376}
Epoch 49/100


100%|██████████| 60/60 [00:34<00:00,  1.76it/s]


Epoch [49/100], Loss: 0.3808
Accuracy at epoch 48: {'accuracy': 0.7938144207000732, 'precision': 0.5, 'recall': 0.15000000596046448, 'f1_score': 0.23076923191547394}
Epoch 50/100


100%|██████████| 60/60 [00:34<00:00,  1.76it/s]


Epoch [50/100], Loss: 0.8247
Accuracy at epoch 49: {'accuracy': 0.7319587469100952, 'precision': 0.42105263471603394, 'recall': 0.800000011920929, 'f1_score': 0.5517241358757019}
Epoch 51/100


100%|██████████| 60/60 [00:34<00:00,  1.74it/s]


Epoch [51/100], Loss: 0.1209
Accuracy at epoch 50: {'accuracy': 0.8453608155250549, 'precision': 0.7272727489471436, 'recall': 0.4000000059604645, 'f1_score': 0.5161290168762207}
Epoch 52/100


100%|██████████| 60/60 [00:34<00:00,  1.76it/s]


Epoch [52/100], Loss: 0.6077
Accuracy at epoch 51: {'accuracy': 0.8144329786300659, 'precision': 0.5625, 'recall': 0.44999998807907104, 'f1_score': 0.5}
Epoch 53/100


100%|██████████| 60/60 [00:34<00:00,  1.76it/s]


Epoch [53/100], Loss: 0.7142
Accuracy at epoch 52: {'accuracy': 0.8350515365600586, 'precision': 0.6000000238418579, 'recall': 0.6000000238418579, 'f1_score': 0.6000000238418579}
Epoch 54/100


100%|██████████| 60/60 [00:34<00:00,  1.73it/s]


Epoch [54/100], Loss: 1.3747
Accuracy at epoch 53: {'accuracy': 0.7525773048400879, 'precision': 0.40909090638160706, 'recall': 0.44999998807907104, 'f1_score': 0.4285714328289032}
Epoch 55/100


100%|██████████| 60/60 [00:33<00:00,  1.81it/s]


Epoch [55/100], Loss: 0.2539
Accuracy at epoch 54: {'accuracy': 0.8041236996650696, 'precision': 0.5454545617103577, 'recall': 0.30000001192092896, 'f1_score': 0.3870967626571655}
Epoch 56/100


100%|██████████| 60/60 [00:35<00:00,  1.70it/s]


Epoch [56/100], Loss: 2.0380
Accuracy at epoch 55: {'accuracy': 0.8247422575950623, 'precision': 0.6153846383094788, 'recall': 0.4000000059604645, 'f1_score': 0.4848484992980957}
Epoch 57/100


100%|██████████| 60/60 [00:34<00:00,  1.75it/s]


Epoch [57/100], Loss: 0.0635
Accuracy at epoch 56: {'accuracy': 0.7835051417350769, 'precision': 0.4545454680919647, 'recall': 0.25, 'f1_score': 0.32258063554763794}
Epoch 58/100


100%|██████████| 60/60 [00:34<00:00,  1.75it/s]


Epoch [58/100], Loss: 0.3763
Accuracy at epoch 57: {'accuracy': 0.8144329786300659, 'precision': 0.6000000238418579, 'recall': 0.30000001192092896, 'f1_score': 0.4000000059604645}
Epoch 59/100


100%|██████████| 60/60 [00:34<00:00,  1.73it/s]


Epoch [59/100], Loss: 0.4422
Accuracy at epoch 58: {'accuracy': 0.8144329786300659, 'precision': 0.550000011920929, 'recall': 0.550000011920929, 'f1_score': 0.550000011920929}
Epoch 60/100


100%|██████████| 60/60 [00:34<00:00,  1.75it/s]


Epoch [60/100], Loss: 0.3904
Accuracy at epoch 59: {'accuracy': 0.8350515365600586, 'precision': 0.625, 'recall': 0.5, 'f1_score': 0.5555555820465088}
Epoch 61/100


100%|██████████| 60/60 [00:34<00:00,  1.72it/s]


Epoch [61/100], Loss: 0.3187
Accuracy at epoch 60: {'accuracy': 0.8453608155250549, 'precision': 0.8571428656578064, 'recall': 0.30000001192092896, 'f1_score': 0.4444444477558136}
Epoch 62/100


100%|██████████| 60/60 [00:34<00:00,  1.72it/s]


Epoch [62/100], Loss: 1.6771
Accuracy at epoch 61: {'accuracy': 0.8144329786300659, 'precision': 0.5555555820465088, 'recall': 0.5, 'f1_score': 0.5263158082962036}
Epoch 63/100


100%|██████████| 60/60 [00:34<00:00,  1.73it/s]


Epoch [63/100], Loss: 0.9798
Accuracy at epoch 62: {'accuracy': 0.8144329786300659, 'precision': 0.6000000238418579, 'recall': 0.30000001192092896, 'f1_score': 0.4000000059604645}
Epoch 64/100


100%|██████████| 60/60 [00:34<00:00,  1.73it/s]


Epoch [64/100], Loss: 0.3892
Accuracy at epoch 63: {'accuracy': 0.8144329786300659, 'precision': 0.550000011920929, 'recall': 0.550000011920929, 'f1_score': 0.550000011920929}
Epoch 65/100


100%|██████████| 60/60 [00:34<00:00,  1.73it/s]


Epoch [65/100], Loss: 0.0659
Accuracy at epoch 64: {'accuracy': 0.8144329786300659, 'precision': 0.6666666865348816, 'recall': 0.20000000298023224, 'f1_score': 0.3076923191547394}
Epoch 66/100


100%|██████████| 60/60 [00:34<00:00,  1.75it/s]


Epoch [66/100], Loss: 0.6174
Accuracy at epoch 65: {'accuracy': 0.8350515365600586, 'precision': 0.5625, 'recall': 0.8999999761581421, 'f1_score': 0.692307710647583}
Epoch 67/100


100%|██████████| 60/60 [00:33<00:00,  1.81it/s]


Epoch [67/100], Loss: 0.2225
Accuracy at epoch 66: {'accuracy': 0.7731958627700806, 'precision': 0.4000000059604645, 'recall': 0.20000000298023224, 'f1_score': 0.2666666805744171}
Epoch 68/100


100%|██████████| 60/60 [00:34<00:00,  1.73it/s]


Epoch [68/100], Loss: 0.0816
Accuracy at epoch 67: {'accuracy': 0.8453608155250549, 'precision': 0.7777777910232544, 'recall': 0.3499999940395355, 'f1_score': 0.48275861144065857}
Epoch 69/100


100%|██████████| 60/60 [00:34<00:00,  1.76it/s]


Epoch [69/100], Loss: 2.1938
Accuracy at epoch 68: {'accuracy': 0.7628865838050842, 'precision': 0.38461539149284363, 'recall': 0.25, 'f1_score': 0.3030303120613098}
Epoch 70/100


100%|██████████| 60/60 [00:33<00:00,  1.77it/s]


Epoch [70/100], Loss: 0.2916
Accuracy at epoch 69: {'accuracy': 0.8247422575950623, 'precision': 0.6363636255264282, 'recall': 0.3499999940395355, 'f1_score': 0.4516128897666931}
Epoch 71/100


100%|██████████| 60/60 [00:33<00:00,  1.77it/s]


Epoch [71/100], Loss: 0.2824
Accuracy at epoch 70: {'accuracy': 0.7835051417350769, 'precision': 0.4444444477558136, 'recall': 0.20000000298023224, 'f1_score': 0.27586206793785095}
Epoch 72/100


100%|██████████| 60/60 [00:34<00:00,  1.75it/s]


Epoch [72/100], Loss: 0.2455
Accuracy at epoch 71: {'accuracy': 0.7938144207000732, 'precision': 0.5, 'recall': 0.20000000298023224, 'f1_score': 0.2857142984867096}
Epoch 73/100


100%|██████████| 60/60 [00:34<00:00,  1.76it/s]


Epoch [73/100], Loss: 0.1113
Accuracy at epoch 72: {'accuracy': 0.8144329786300659, 'precision': 0.6666666865348816, 'recall': 0.20000000298023224, 'f1_score': 0.3076923191547394}
Epoch 74/100


100%|██████████| 60/60 [00:34<00:00,  1.76it/s]


Epoch [74/100], Loss: 0.7562
Accuracy at epoch 73: {'accuracy': 0.8453608155250549, 'precision': 0.8571428656578064, 'recall': 0.30000001192092896, 'f1_score': 0.4444444477558136}
Epoch 75/100


100%|██████████| 60/60 [00:33<00:00,  1.81it/s]


Epoch [75/100], Loss: 0.1709
Accuracy at epoch 74: {'accuracy': 0.8144329786300659, 'precision': 0.5833333134651184, 'recall': 0.3499999940395355, 'f1_score': 0.4375}
Epoch 76/100


100%|██████████| 60/60 [00:33<00:00,  1.77it/s]


Epoch [76/100], Loss: 0.4440
Accuracy at epoch 75: {'accuracy': 0.8144329786300659, 'precision': 0.5714285969734192, 'recall': 0.4000000059604645, 'f1_score': 0.47058823704719543}
Epoch 77/100


100%|██████████| 60/60 [00:34<00:00,  1.76it/s]


Epoch [77/100], Loss: 0.3845
Accuracy at epoch 76: {'accuracy': 0.8144329786300659, 'precision': 0.5714285969734192, 'recall': 0.4000000059604645, 'f1_score': 0.47058823704719543}
Epoch 78/100


100%|██████████| 60/60 [00:34<00:00,  1.76it/s]


Epoch [78/100], Loss: 1.2435
Accuracy at epoch 77: {'accuracy': 0.7835051417350769, 'precision': 0.4615384638309479, 'recall': 0.30000001192092896, 'f1_score': 0.3636363744735718}
Epoch 79/100


100%|██████████| 60/60 [00:34<00:00,  1.75it/s]


Epoch [79/100], Loss: 0.5311
Accuracy at epoch 78: {'accuracy': 0.7835051417350769, 'precision': 0.4615384638309479, 'recall': 0.30000001192092896, 'f1_score': 0.3636363744735718}
Epoch 80/100


100%|██████████| 60/60 [00:34<00:00,  1.74it/s]


Epoch [80/100], Loss: 0.1764
Accuracy at epoch 79: {'accuracy': 0.8556700944900513, 'precision': 0.6363636255264282, 'recall': 0.699999988079071, 'f1_score': 0.6666666865348816}
Epoch 81/100


100%|██████████| 60/60 [00:34<00:00,  1.73it/s]


Epoch [81/100], Loss: 0.1816
Accuracy at epoch 80: {'accuracy': 0.8144329786300659, 'precision': 0.5555555820465088, 'recall': 0.5, 'f1_score': 0.5263158082962036}
Epoch 82/100


100%|██████████| 60/60 [00:34<00:00,  1.72it/s]


Epoch [82/100], Loss: 0.2335
Accuracy at epoch 81: {'accuracy': 0.8041236996650696, 'precision': 0.5384615659713745, 'recall': 0.3499999940395355, 'f1_score': 0.42424243688583374}
Epoch 83/100


100%|██████████| 60/60 [00:34<00:00,  1.74it/s]


Epoch [83/100], Loss: 0.2333
Accuracy at epoch 82: {'accuracy': 0.8247422575950623, 'precision': 0.6153846383094788, 'recall': 0.4000000059604645, 'f1_score': 0.4848484992980957}
Epoch 84/100


100%|██████████| 60/60 [00:34<00:00,  1.73it/s]


Epoch [84/100], Loss: 0.0509
Accuracy at epoch 83: {'accuracy': 0.8144329786300659, 'precision': 0.550000011920929, 'recall': 0.550000011920929, 'f1_score': 0.550000011920929}
Epoch 85/100


100%|██████████| 60/60 [00:34<00:00,  1.72it/s]


Epoch [85/100], Loss: 0.1288
Accuracy at epoch 84: {'accuracy': 0.8247422575950623, 'precision': 0.6153846383094788, 'recall': 0.4000000059604645, 'f1_score': 0.4848484992980957}
Epoch 86/100


100%|██████████| 60/60 [00:35<00:00,  1.69it/s]


Epoch [86/100], Loss: 0.6130
Accuracy at epoch 85: {'accuracy': 0.8350515365600586, 'precision': 0.5625, 'recall': 0.8999999761581421, 'f1_score': 0.692307710647583}
Epoch 87/100


100%|██████████| 60/60 [00:33<00:00,  1.77it/s]


Epoch [87/100], Loss: 0.2127
Accuracy at epoch 86: {'accuracy': 0.8144329786300659, 'precision': 0.5714285969734192, 'recall': 0.4000000059604645, 'f1_score': 0.47058823704719543}
Epoch 88/100


100%|██████████| 60/60 [00:33<00:00,  1.81it/s]


Epoch [88/100], Loss: 1.0294
Accuracy at epoch 87: {'accuracy': 0.8350515365600586, 'precision': 0.6000000238418579, 'recall': 0.6000000238418579, 'f1_score': 0.6000000238418579}
Epoch 89/100


100%|██████████| 60/60 [00:33<00:00,  1.81it/s]


Epoch [89/100], Loss: 0.2884
Accuracy at epoch 88: {'accuracy': 0.8247422575950623, 'precision': 0.5789473652839661, 'recall': 0.550000011920929, 'f1_score': 0.5641025900840759}
Epoch 90/100


100%|██████████| 60/60 [00:33<00:00,  1.79it/s]


Epoch [90/100], Loss: 1.1629
Accuracy at epoch 89: {'accuracy': 0.876288652420044, 'precision': 0.6538461446762085, 'recall': 0.8500000238418579, 'f1_score': 0.739130437374115}
Epoch 91/100


100%|██████████| 60/60 [00:33<00:00,  1.77it/s]


Epoch [91/100], Loss: 0.3006
Accuracy at epoch 90: {'accuracy': 0.8144329786300659, 'precision': 0.5555555820465088, 'recall': 0.5, 'f1_score': 0.5263158082962036}
Epoch 92/100


100%|██████████| 60/60 [00:33<00:00,  1.80it/s]


Epoch [92/100], Loss: 0.4685
Accuracy at epoch 91: {'accuracy': 0.8247422575950623, 'precision': 0.5789473652839661, 'recall': 0.550000011920929, 'f1_score': 0.5641025900840759}
Epoch 93/100


100%|██████████| 60/60 [00:33<00:00,  1.81it/s]


Epoch [93/100], Loss: 1.5025
Accuracy at epoch 92: {'accuracy': 0.8247422575950623, 'precision': 0.5882353186607361, 'recall': 0.5, 'f1_score': 0.5405405163764954}
Epoch 94/100


100%|██████████| 60/60 [00:32<00:00,  1.83it/s]


Epoch [94/100], Loss: 0.0547
Accuracy at epoch 93: {'accuracy': 0.8247422575950623, 'precision': 0.6153846383094788, 'recall': 0.4000000059604645, 'f1_score': 0.4848484992980957}
Epoch 95/100


100%|██████████| 60/60 [00:32<00:00,  1.83it/s]


Epoch [95/100], Loss: 0.1225
Accuracy at epoch 94: {'accuracy': 0.8041236996650696, 'precision': 0.6000000238418579, 'recall': 0.15000000596046448, 'f1_score': 0.23999999463558197}
Epoch 96/100


100%|██████████| 60/60 [00:32<00:00,  1.83it/s]


Epoch [96/100], Loss: 0.2501
Accuracy at epoch 95: {'accuracy': 0.8350515365600586, 'precision': 0.75, 'recall': 0.30000001192092896, 'f1_score': 0.4285714328289032}
Epoch 97/100


100%|██████████| 60/60 [00:33<00:00,  1.81it/s]


Epoch [97/100], Loss: 0.9217
Accuracy at epoch 96: {'accuracy': 0.8247422575950623, 'precision': 0.5789473652839661, 'recall': 0.550000011920929, 'f1_score': 0.5641025900840759}
Epoch 98/100


100%|██████████| 60/60 [00:33<00:00,  1.80it/s]


Epoch [98/100], Loss: 0.3176
Accuracy at epoch 97: {'accuracy': 0.8659793734550476, 'precision': 0.6521739363670349, 'recall': 0.75, 'f1_score': 0.6976743936538696}
Epoch 99/100


100%|██████████| 60/60 [00:34<00:00,  1.74it/s]


Epoch [99/100], Loss: 0.2809
Accuracy at epoch 98: {'accuracy': 0.8247422575950623, 'precision': 0.5789473652839661, 'recall': 0.550000011920929, 'f1_score': 0.5641025900840759}
Epoch 100/100


100%|██████████| 60/60 [00:34<00:00,  1.75it/s]


Epoch [100/100], Loss: 1.7935
Accuracy at epoch 99: {'accuracy': 0.8556700944900513, 'precision': 0.6153846383094788, 'recall': 0.800000011920929, 'f1_score': 0.695652186870575}
