In [None]:
import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam

from torch.utils.data import Dataset, DataLoader, random_split
from utils.video_dataset import VideoDataset

from models.resnet26_3D import resnet26, resnet26b
from utils.augmentation import AugmentationGAN
from utils.pytorch_tools import gpu_usage
from utils.train_functions import train_one_epoch, eval_one_epoch

from tqdm import tqdm
import numpy as np
import logging
from pathlib import Path

In [None]:
# Pick the compute device: CUDA when PyTorch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print('Using device:', device)

# Project helper from utils.pytorch_tools; called here before anything is moved to the device.
gpu_usage()

In [None]:
# Data-loading configuration shared by the cells below.
batch_size = 1

# Slicing parameters passed to the video dataset.
slice_length = 10
stride = 5

# Frame size: one tuple for the model, plus the unpacked components for the dataset.
input_shape = (224, 224)
width, height = input_shape

In [None]:
# Build the dataset from the values defined in the configuration cell, so that
# changing `slice_length` / `stride` there actually takes effect here
# (they were previously duplicated as literals).
dataset = VideoDataset(root_dir='data/data_videos_PART2',
                       width=width, height=height,
                       slice_length=slice_length, stride=stride)

# Fractions of the dataset assigned to train / val / test, in that order.
ratios = [0.98, 0.01, 0.01]

total = len(dataset)
lengths = [int(r * total) for r in ratios]
# int() truncates, so give the remainder to the last split to make the sizes sum to `total`.
lengths[-1] = total - sum(lengths[:-1])

print('Total data:', total, 'and the splits are:', lengths, 'train, val, test')

# Seed the split so the same samples land in each subset after a kernel restart;
# otherwise the "test" set evaluated later may contain samples a saved model was trained on.
split_generator = torch.Generator().manual_seed(42)
train, val, test = random_split(dataset, lengths=lengths, generator=split_generator)

# Arguments common to all three loaders.
loader_kwargs = dict(batch_size=batch_size, collate_fn=dataset.collate,
                     num_workers=1, pin_memory=True)

# Only the training loader reshuffles every epoch; val/test keep a fixed order.
train_loader = DataLoader(train, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val, **loader_kwargs)
test_loader = DataLoader(test, **loader_kwargs)

In [None]:
# # Get train data
# trainset = VideoDataset('data/data_videos_PART2_separated/train', height = 224, width = 224)
# train_loader = DataLoader(trainset, batch_size = batch_size, shuffle = False, pin_memory = True)

# # Get validation data
# validationset = VideoDataset('data/data_videos_PART2_separated/validation', height = 224, width = 224)
# validation_loader = DataLoader(validationset, batch_size = batch_size, shuffle = False, pin_memory = True)

# # Get test data
# testset = VideoDataset('data/data_videos_PART2_separated/test', height = 224, width = 224)
# test_loader = DataLoader(testset, batch_size = batch_size, shuffle = False, pin_memory = True)

In [None]:
# Arguments handed to the model constructor below.
num_classes = 2
sample_size = 1
sample_duration = 1

# Instantiate the network and move it to the selected device in a single expression.
model = resnet26b(
    sample_size=sample_size,
    sample_duration=sample_duration,
    input_shape=input_shape,
    num_classes=num_classes,
    last_fc=True,
).to(device)

# Report GPU usage again now that the model has been moved to the device.
gpu_usage()

In [None]:
# Training the model
num_epochs = 10
learning_rate = 0.0001

criterion = CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=learning_rate)

# Directory that receives one checkpoint file per improving epoch.
checkpoint_dir = Path('checkpoint')
checkpoint_dir.mkdir(exist_ok=True)

val_acc = 0  # best validation accuracy seen so far
for epoch in range(1, num_epochs + 1):  # Loop over the dataset multiple times

    train_results = train_one_epoch(model, train_loader, optimizer, criterion, epoch, device)
    val_results = eval_one_epoch(model, val_loader, criterion, epoch, device)

    results = {**train_results, **val_results}  # combine train and val results

    # Index directly so a missing key fails with a clear KeyError instead of
    # a `None > int` TypeError.
    epoch_val_acc = results['val_acc']

    if epoch_val_acc > val_acc:  # save checkpoint only if val_acc improved
        # Remember the new best; without this update every epoch was compared
        # against 0 and a checkpoint was written regardless of improvement.
        val_acc = epoch_val_acc
        torch.save({'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    **results}, checkpoint_dir / f'checkpoint_{epoch}.pt')

print('Finished Training')

In [None]:
# loading last checkpoint

# Discover the checkpoint files that actually exist instead of guessing names:
# the training loop numbers epochs from 1 and may skip epochs, and a plain
# string sort would rank 'checkpoint_9.pt' above 'checkpoint_10.pt'.
checkpoint_paths = sorted(
    (path for path in Path('checkpoint').glob('checkpoint_*.pt')
     if path.stem.rsplit('_', 1)[-1].isdigit()),
    key=lambda path: int(path.stem.rsplit('_', 1)[-1]),  # order by numeric epoch
)
if not checkpoint_paths:
    raise FileNotFoundError("No 'checkpoint/checkpoint_<epoch>.pt' file found; run the training cell first.")

# Highest epoch number on disk. NOTE(review): files left over from an earlier
# run in the same directory are picked up too — clear the folder between runs if that matters.
last_checkpoint = checkpoint_paths[-1]

# map_location lets a checkpoint saved on GPU be loaded when only the CPU is available.
checkpoint = torch.load(last_checkpoint, map_location=device)

model.load_state_dict(checkpoint['model_state_dict'])

In [None]:
# Evaluating the model on the testset
correct = 0
total = 0

# Switch layers such as batch-norm / dropout to inference behaviour before testing.
model.eval()

with torch.no_grad():  # no gradients needed for evaluation
    # We are going to evaluate the model on the testset only !
    # (the loader is named `test_loader`; `testloader` was undefined)
    for video, labels in test_loader:
        # Move inputs and labels to the same device as the model
        video, labels = video.to(device), labels.to(device)

        outputs = model(video)
        # Get the indexes of maximum values along the second axis
        _, predicted = torch.max(outputs, dim=1)
        total += labels.size(0)
        # Add the number of correct predictions for the batch to the total count
        correct += (predicted == labels).sum().item()

print(f"Test acccuracy: {(100 * correct / total)}%")

In [None]:
# Loading the model once it's trained