In [1]:
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import os
from PIL import Image
import torch
from sklearn.model_selection import train_test_split
import pandas as pd
from torchvision.transforms import transforms, Compose, ToTensor, Resize, Normalize, CenterCrop, Grayscale
from torch import nn
from tqdm import tqdm
from torchinfo import summary
import numpy as np
import math
from torchvision.models.video import r3d_18, R3D_18_Weights, mc3_18, MC3_18_Weights

In [2]:
class SnakeDataSet(Dataset):
    def __init__(self, dataframe, root_dir, stack_size, transform = None):
        self.stack_size = stack_size
        self.key_frame = dataframe
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.key_frame) - self.stack_size *3

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.to_list()
        try:
            img_names = [os.path.join(self.root_dir, self.key_frame.iloc[idx + i, 0]) for i in range(self.stack_size)]
            images = [Image.open(img_name) for img_name in img_names]
            label = torch.tensor(self.key_frame.iloc[idx + self.stack_size, 1])
            if self.transform:
                images = [self.transform(image) for image in images]
        except:
            img_names = [os.path.join(self.root_dir, self.key_frame.iloc[0 + i, 0]) for i in range(self.stack_size)]
            images = [Image.open(img_name) for img_name in img_names]
            label = torch.tensor(self.key_frame.iloc[0 + self.stack_size, 1])
            if self.transform:
                images = [self.transform(image) for image in images]
        return torch.stack(images,dim = 1).squeeze(), label

In [3]:
transformer = Compose([
    Resize((84,84), antialias=True),
    CenterCrop(84),
    ToTensor(),
    Normalize(mean =[ 0.8725,  1.8742, -0.2931], std =[0.3376, 0.3561, 0.3825] )
])

In [4]:
STACK_SIZE = 4

train, test = train_test_split(pd.read_csv("data/labels_snake.csv"), test_size=0.2, shuffle=False)
classes = ["n", "left", "up", "right", "down"]

labels_unique, counts = np.unique(train["class"], return_counts=True)
class_weights = [sum(counts)/c for c in counts]
example_weights = np.array([class_weights[l] for l in train['class']])
example_weights = np.roll(example_weights, -STACK_SIZE)
sampler = WeightedRandomSampler(example_weights, len(train))

labels_unique, counts = np.unique(test["class"], return_counts=True)
class_weights = [sum(counts)/c for c in counts]
test_example_weights = np.array([class_weights[l] for l in test['class']])
test_example_weights = np.roll(test_example_weights, -STACK_SIZE)
test_sampler = WeightedRandomSampler(test_example_weights, len(test))

In [5]:
from torch.utils.data import DataLoader
BATCH_SIZE = 32
dataset = SnakeDataSet(root_dir="captures", dataframe = train, stack_size=STACK_SIZE, transform=transformer)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=sampler, drop_last= True)
test_dataset = SnakeDataSet(root_dir="captures", dataframe = test, stack_size=STACK_SIZE,  transform=transformer)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, sampler = test_sampler, drop_last=True)


In [6]:
def compute_mean_std(dataloader):
    '''
    We assume that the images of the dataloader have the same height and width
    source: https://github.com/aladdinpersson/Machine-Learning-Collection/blob/master/ML/Pytorch/Basics/pytorch_std_mean.py
    '''
    # var[X] = E[X**2] - E[X]**2
    channels_sum, channels_sqrd_sum, num_batches = 0, 0, 0

    for batch_images, labels in tqdm(dataloader):  # (B,H,W,C)
        batch_images = batch_images.permute(0,3,4,2,1)
        channels_sum += torch.mean(batch_images, dim=[0, 1, 2, 3])
        channels_sqrd_sum += torch.mean(batch_images ** 2, dim=[0, 1, 2,3])
        num_batches += 1

    mean = channels_sum / num_batches
    std = (channels_sqrd_sum / num_batches - mean ** 2) ** 0.5

    return mean, std

In [7]:
# compute_mean_std(dataloader)

In [8]:
images, label = next(iter(dataloader))
# print(label)
# from matplotlib.pyplot import subplots
# fig,ax = subplots(1, STACK_SIZE, constrained_layout = True, figsize=(15,5))
# for i in range(STACK_SIZE):
#     ax[i].imshow(images[2][i,:,:],cmap="gray");
#     ax[i].axis('off')
# fig.supylabel(classes[label[2].item()]);

In [9]:
images.shape

torch.Size([32, 3, 4, 84, 84])

In [10]:
# model = mc3_18(weights = MC3_18_Weights.DEFAULT)
model = r3d_18(weights = R3D_18_Weights.DEFAULT)
# for param in model.parameters():
#     param.requires_grad = False
model.fc = nn.Linear(in_features=512, out_features=5, bias=True)
summary(model, (32,3,4,84,84))

Layer (type:depth-idx)                   Output Shape              Param #
VideoResNet                              [32, 5]                   --
├─BasicStem: 1-1                         [32, 64, 4, 42, 42]       --
│    └─Conv3d: 2-1                       [32, 64, 4, 42, 42]       28,224
│    └─BatchNorm3d: 2-2                  [32, 64, 4, 42, 42]       128
│    └─ReLU: 2-3                         [32, 64, 4, 42, 42]       --
├─Sequential: 1-2                        [32, 64, 4, 42, 42]       --
│    └─BasicBlock: 2-4                   [32, 64, 4, 42, 42]       --
│    │    └─Sequential: 3-1              [32, 64, 4, 42, 42]       110,720
│    │    └─Sequential: 3-2              [32, 64, 4, 42, 42]       110,720
│    │    └─ReLU: 3-3                    [32, 64, 4, 42, 42]       --
│    └─BasicBlock: 2-5                   [32, 64, 4, 42, 42]       --
│    │    └─Sequential: 3-4              [32, 64, 4, 42, 42]       110,720
│    │    └─Sequential: 3-5              [32, 64, 4, 42, 42]     

In [11]:
num_epochs = 1

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
optimizer = torch.optim.AdamW(model.parameters(), 10e-5, weight_decay=0.1)
model.to(device)
criterion = nn.CrossEntropyLoss()

In [12]:
for epoch in range(num_epochs):
    total_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    val_loss = 0.0
    val_correct_predictions = 0
    val_total_samples = 0

    # Set model to training mode
    model.train()

    # tqdm bar for progress visualization
    pbar = tqdm(dataloader, desc=f'Epoch {epoch + 1}/{num_epochs}', leave=True)
    for inputs, labels in pbar:
        inputs, labels = inputs.to(device), labels.to(device)

        outputs = model(inputs.to(device))
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update statistics
        total_loss += loss.item()
        _, predicted = torch.max(torch.softmax(outputs,1), 1)
        correct_predictions += (predicted == labels).sum().item()
        total_samples += labels.size(0)

        # Update tqdm bar with current loss and accuracy
        pbar.set_postfix({'Loss': total_loss / total_samples, 'Accuracy': correct_predictions / total_samples})
        steps = steps + 1

    model.eval()
    with torch.inference_mode():
        for inputs, labels in test_dataloader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs.to(device))
            loss = criterion(outputs, labels)

            # Update statistics
            val_loss += loss.item()
            _, predicted = torch.max(torch.softmax(outputs,1), 1)
            val_correct_predictions += (predicted == labels).sum().item()
            val_total_samples += labels.size(0)

    # Calculate and print epoch-level accuracy and loss for validation
    epoch_loss = val_loss / val_total_samples
    epoch_accuracy = val_correct_predictions / val_total_samples
    print(f'Epoch {epoch + 1}/{num_epochs}, Val Loss: {epoch_loss:.4f}, Val Accuracy: {epoch_accuracy:.4f}')
    torch.save(model.state_dict(), "model_r3d.pth")
    

Epoch 1/1: 100%|██████████| 3905/3905 [39:05<00:00,  1.67it/s, Loss=0.0099, Accuracy=0.892] 


Epoch 1/1, Val Loss: 0.0247, Val Accuracy: 0.7776


In [13]:
torch.save(model.state_dict(), "model_r3d.pth")

In [14]:
model = mc3_18(weights = MC3_18_Weights.DEFAULT)
model.fc = nn.Linear(in_features=512, out_features=5, bias=True)

num_epochs = 1

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
optimizer = torch.optim.AdamW(model.parameters(), 10e-5, weight_decay=0.1)
model.to(device)
criterion = nn.CrossEntropyLoss()

In [15]:
for epoch in range(num_epochs):
    total_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    val_loss = 0.0
    val_correct_predictions = 0
    val_total_samples = 0

    # Set model to training mode
    model.train()

    # tqdm bar for progress visualization
    pbar = tqdm(dataloader, desc=f'Epoch {epoch + 1}/{num_epochs}', leave=True)
    for inputs, labels in pbar:
        inputs, labels = inputs.to(device), labels.to(device)

        outputs = model(inputs.to(device))
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update statistics
        total_loss += loss.item()
        _, predicted = torch.max(torch.softmax(outputs,1), 1)
        correct_predictions += (predicted == labels).sum().item()
        total_samples += labels.size(0)

        # Update tqdm bar with current loss and accuracy
        pbar.set_postfix({'Loss': total_loss / total_samples, 'Accuracy': correct_predictions / total_samples})
        steps = steps + 1

    model.eval()
    with torch.inference_mode():
        for inputs, labels in test_dataloader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs.to(device))
            loss = criterion(outputs, labels)

            # Update statistics
            val_loss += loss.item()
            _, predicted = torch.max(torch.softmax(outputs,1), 1)
            val_correct_predictions += (predicted == labels).sum().item()
            val_total_samples += labels.size(0)

    # Calculate and print epoch-level accuracy and loss for validation
    epoch_loss = val_loss / val_total_samples
    epoch_accuracy = val_correct_predictions / val_total_samples
    print(f'Epoch {epoch + 1}/{num_epochs}, Val Loss: {epoch_loss:.4f}, Val Accuracy: {epoch_accuracy:.4f}')
    torch.save(model.state_dict(), "model_mc3.pth")

Epoch 1/1: 100%|██████████| 3905/3905 [35:54<00:00,  1.81it/s, Loss=0.00966, Accuracy=0.895]


Epoch 1/1, Val Loss: 0.0241, Val Accuracy: 0.7747


: 