# PyTorch 3D ResNet

In [42]:
import torch
from tqdm import tqdm
from time import time
from sklearn.metrics import accuracy_score, r2_score, precision_score, recall_score, f1_score
import pandas as pd

from torch.utils.data import DataLoader, Subset, random_split

from models.ThreeDResNet import get_3dResNet, get_resnet_transformer
from colorVideoDataset import ColorVideoDataset

In [2]:
MODEL = get_3dResNet()
DATASET = ColorVideoDataset('./colors')

Using cache found in C:\Users\Arnav Waghdhare/.cache\torch\hub\facebookresearch_pytorchvideo_main


Net(
  (blocks): ModuleList(
    (0): ResNetBasicStem(
      (conv): Conv3d(3, 64, kernel_size=(1, 7, 7), stride=(1, 2, 2), padding=(0, 3, 3), bias=False)
      (norm): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation): ReLU()
      (pool): MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=[0, 1, 1], dilation=1, ceil_mode=False)
    )
    (1): ResStage(
      (res_blocks): ModuleList(
        (0): ResBlock(
          (branch1_conv): Conv3d(64, 256, kernel_size=(1, 1, 1), stride=(np.int64(1), np.int64(1), np.int64(1)), bias=False)
          (branch1_norm): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (branch2): BottleneckBlock(
            (conv_a): Conv3d(64, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
            (norm_a): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (act_a): ReLU()
            (conv_b): Conv3d(64, 64, kern

Train, Test, Val Sets

In [None]:
train_dataset, test_dataset = random_split(DATASET, [int(0.8 * len(DATASET)), len(DATASET) - int(0.8 * len(DATASET))])
test_dataset, val_dataset = random_split(test_dataset, [int(0.5 * len(test_dataset)), len(test_dataset) - int(0.5 * len(test_dataset))])

In [28]:
def get_dataloader(dataset, subset_ratio : float | None = 0.1, batch_size : int = 2):
    transform = get_resnet_transformer()

    def collate_fn(batch):
        videos = []
        labels = []
        for video, label, _ in batch:
            video = video.permute(1, 0, 2, 3)
            video = transform({"video": video})["video"]
            videos.append(video)
            labels.append(torch.tensor(label, dtype=torch.long))
        
        videos = torch.stack(videos)
        labels = torch.stack(labels)
        return videos, labels
    
    if subset_ratio is not None:
        num_samples = int(len(dataset) * subset_ratio) 
        subset_indices = list(range(num_samples))
        subset = Subset(dataset, subset_indices)
        dataloader = DataLoader(subset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    else:
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    return dataloader

In [None]:
def training_loop(model, dataloader, epochs=5, learning_rate=1e-4):
    results_dict = {
        'time_per_batch': [],
        'time_per_epoch': [],
        'accuracy': [],
        'precision': [],
        'recall': [],
        'f1': [],
        'r2': [],
        'loss': []
    }
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    model.train()
    for epoch in tqdm(range(epochs), desc="Training Epochs"):
        running_loss = 0.0
        epoch_start = time()
        all_preds = []
        all_labels = []
        
        for videos, labels in dataloader:
            batch_start = time()
            videos, labels = videos.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(videos)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            batch_time = time() - batch_start
            results_dict['time_per_batch'].append(batch_time)
            
            # Store predictions and labels for metrics
            all_preds.extend(outputs.argmax(dim=1).cpu().detach().numpy())
            all_labels.extend(labels.cpu().detach().numpy())
        
        epoch_time = time() - epoch_start
        results_dict['time_per_epoch'].append(epoch_time)
        
        # Calculate metrics
        results_dict['accuracy'].append(accuracy_score(all_labels, all_preds))
        results_dict['precision'].append(precision_score(all_labels, all_preds, average='weighted', zero_division=0))
        results_dict['recall'].append(recall_score(all_labels, all_preds, average='weighted', zero_division=0))
        results_dict['f1'].append(f1_score(all_labels, all_preds, average='weighted', zero_division=0))
        results_dict['r2'].append(r2_score(all_labels, all_preds))
        
        avg_loss = running_loss / len(dataloader)
        results_dict['loss'].append(avg_loss)
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}")
    
    return results_dict

In [34]:
dataloader = get_dataloader(DATASET, subset_ratio=0.1, batch_size=16)
x, y = next(iter(dataloader))
print(x.shape, y.shape)

torch.Size([16, 3, 8, 256, 256]) torch.Size([16])


In [40]:
results = training_loop(MODEL, dataloader, epochs=10, learning_rate=2e-4)
print(results)

Training Epochs:  10%|█         | 1/10 [00:04<00:41,  4.57s/it]

Epoch [1/10], Loss: 0.9398


Training Epochs:  20%|██        | 2/10 [00:08<00:34,  4.31s/it]

Epoch [2/10], Loss: 0.7308


Training Epochs:  30%|███       | 3/10 [00:12<00:29,  4.23s/it]

Epoch [3/10], Loss: 0.5720


Training Epochs:  40%|████      | 4/10 [00:17<00:25,  4.23s/it]

Epoch [4/10], Loss: 0.4330


Training Epochs:  50%|█████     | 5/10 [00:21<00:21,  4.21s/it]

Epoch [5/10], Loss: 0.3305


Training Epochs:  60%|██████    | 6/10 [00:25<00:16,  4.20s/it]

Epoch [6/10], Loss: 0.2482


Training Epochs:  70%|███████   | 7/10 [00:29<00:12,  4.21s/it]

Epoch [7/10], Loss: 0.1903


Training Epochs:  80%|████████  | 8/10 [00:33<00:08,  4.25s/it]

Epoch [8/10], Loss: 0.1524


Training Epochs:  90%|█████████ | 9/10 [00:38<00:04,  4.24s/it]

Epoch [9/10], Loss: 0.1173


Training Epochs: 100%|██████████| 10/10 [00:42<00:00,  4.23s/it]

Epoch [10/10], Loss: 0.0981
{'time_per_batch': [0.7560360431671143, 0.43861937522888184, 0.39720797538757324, 0.3898773193359375, 0.3858981132507324, 0.39762234687805176, 0.45095157623291016, 0.3962078094482422, 0.3863823413848877, 0.4003894329071045, 0.40998196601867676, 0.4006505012512207, 0.3985769748687744, 0.48401641845703125, 0.4817013740539551, 0.40950942039489746, 0.39815568923950195, 0.4507427215576172, 0.4028947353363037, 0.3863494396209717], 'time_per_epoch': [4.5674028396606445, 4.117170810699463, 4.126593828201294, 4.222600698471069, 4.170922517776489, 4.164057493209839, 4.2260901927948, 4.3325722217559814, 4.21908164024353, 4.1101765632629395], 'accuracy': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 'precision': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 'recall': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 'f1': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 'r2': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]}





In [None]:
train_dataloader =  get_dataloader(train_dataset, subset_ratio=None, batch_size=4)
val_dataloader =  get_dataloader(val_dataset, subset_ratio=None, batch_size=4)
test_dataloader =  get_dataloader(test_dataset, subset_ratio=None, batch_size=4)

In [45]:
results

{'time_per_batch': [0.7560360431671143,
  0.43861937522888184,
  0.39720797538757324,
  0.3898773193359375,
  0.3858981132507324,
  0.39762234687805176,
  0.45095157623291016,
  0.3962078094482422,
  0.3863823413848877,
  0.4003894329071045,
  0.40998196601867676,
  0.4006505012512207,
  0.3985769748687744,
  0.48401641845703125,
  0.4817013740539551,
  0.40950942039489746,
  0.39815568923950195,
  0.4507427215576172,
  0.4028947353363037,
  0.3863494396209717],
 'time_per_epoch': [4.5674028396606445,
  4.117170810699463,
  4.126593828201294,
  4.222600698471069,
  4.170922517776489,
  4.164057493209839,
  4.2260901927948,
  4.3325722217559814,
  4.21908164024353,
  4.1101765632629395],
 'accuracy': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
 'precision': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
 'recall': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
 'f1': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
 'r2': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0