# CSN-Pytorch

## Imports

In [61]:
import os
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from csn import csn152
from torchvision import transforms
from video_dataset import VideoFrameDataset, ImglistToTensor

## Device Agnostic Code

In [4]:
try:
    device = 'mps' if torch.backends.mps.is_available() else 'cpu'
except:
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Data Setup

In [87]:
data_root = os.path.join(os.getcwd(), 'data/wlasl/rawframes') 
ann_file_train = os.path.join(os.getcwd(), 'data/wlasl/train_annotations.txt') 
ann_file_test = os.path.join(os.getcwd(), 'data/wlasl/test_annotations.txt')
batch_size = 2

# Setting up data transforms
train_pipeline = transforms.Compose([
        ImglistToTensor(), # list of PIL images to (FRAMES x CHANNELS x HEIGHT x WIDTH) tensor
        transforms.Resize(256), # image batch, resize smaller edge to 256
        transforms.RandomResizedCrop((224, 224)), # image batch, center crop to square 224x224
        transforms.RandomHorizontalFlip(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

test_pipeline = transforms.Compose([
        ImglistToTensor(), # list of PIL images to (FRAMES x CHANNELS x HEIGHT x WIDTH) tensor
        transforms.Resize(256),  # image batch, resize smaller edge to 256
        transforms.CenterCrop((224, 224)),  # image batch, center crop to square 224x224
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

# Setting up datasets
train_dataset = VideoFrameDataset(
    root_path=data_root,
    annotationfile_path=ann_file_train,
    num_segments=5,
    frames_per_segment=1,
    imagefile_template='img_{:05d}.jpg',
    transform=train_pipeline,
    test_mode=False
)


test_dataset = VideoFrameDataset(
    root_path=data_root,
    annotationfile_path=ann_file_test,
    num_segments=5,
    frames_per_segment=1,
    imagefile_template='img_{:05d}.jpg',
    transform=test_pipeline,
    test_mode=True
)

# Setting up dataloaders
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                    batch_size=batch_size,
                                    shuffle=False,
                                    num_workers=4,
                                    pin_memory=False)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                    batch_size=2,
                                    shuffle=False,
                                    num_workers=4,
                                    pin_memory=False)

In [89]:
# Testing dataloader
dataiter = iter(train_loader)
get = next(dataiter)
get[0].shape

torch.Size([2, 5, 3, 224, 224])

## Set up model, loss and optimiser

In [64]:
# Create a CSN model
csn = csn152(num_classes=400)

# Specify loss function
loss_fn = nn.CrossEntropyLoss()

# Specify loss function
optimizer = torch.optim.Adam(csn.parameters(), lr=0.001)

## Train model

In [65]:
epochs = 5
interval = 10 # For checkpoints and validation

losses = []
eval_accu = []
eval_losses=[]

csn.to(device)
csn.train()

for epoch in range(epochs):
    # Reset train loss
    train_loss = 0.0
    for video_batch, targets in train_loader:
        # Move data to device
        video_batch, targets = video_batch.to(device), targets.to(device)
        
        # batch_size, channels, n_frames, h, w
        predictions = csn(video_batch.view(video_batch.size(0), 3, 5, 224, 224))
        
        # Calculate loss
        loss = loss_fn(predictions, targets)
        
        # Backpropagate
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Calculate batch loss
        train_loss += loss.item()*video_batch.size(0)

    # Check for interval to validate and save checkpoints
    if epoch%interval==0:
        running_loss=0
        correct=0
        total=0
        with torch.no_grad():
            for video_batch, targets in test_loader:
                video_batch, targets = video_batch.to(device), targets.to(device)

                predictions = csn(video_batch.view(video_batch.size(0), 3, 5, 224, 224))

                loss = loss_fn(predictions, targets)
                running_loss += loss.item()

                _, predicted = predictions.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        test_loss=running_loss/len(test_loader)
        accu=100.*correct/total

        eval_losses.append(test_loss)
        eval_accu.append(accu)
        print('Test Loss: %.3f | Accuracy: %.3f'%(test_loss,accu)) 

    train_loss = train_loss/len(train_loader)
    losses.append(train_loss)
    print(f'Epoch: {epoch+1} \tTraining Loss: {train_loss:.6f}')

Test Loss: 2.644 | Accuracy: 0.000
Epoch: 1 	Training Loss: 12.320577
Epoch: 2 	Training Loss: 6.218439
Epoch: 3 	Training Loss: 5.693654
Epoch: 4 	Training Loss: 5.565608
Epoch: 5 	Training Loss: 5.368568


In [68]:
import numpy as np

In [75]:
3*5*2*2


60

In [78]:
x = np.arange(0, 60)
x

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
       34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
       51, 52, 53, 54, 55, 56, 57, 58, 59])

In [83]:
video = x.reshape(5, 3, 2, 2)
video

array([[[[ 0,  1],
         [ 2,  3]],

        [[ 4,  5],
         [ 6,  7]],

        [[ 8,  9],
         [10, 11]]],


       [[[12, 13],
         [14, 15]],

        [[16, 17],
         [18, 19]],

        [[20, 21],
         [22, 23]]],


       [[[24, 25],
         [26, 27]],

        [[28, 29],
         [30, 31]],

        [[32, 33],
         [34, 35]]],


       [[[36, 37],
         [38, 39]],

        [[40, 41],
         [42, 43]],

        [[44, 45],
         [46, 47]]],


       [[[48, 49],
         [50, 51]],

        [[52, 53],
         [54, 55]],

        [[56, 57],
         [58, 59]]]])

In [85]:
reshape = video.reshape(3, 5, 2, 2)
reshape

array([[[[ 0,  1],
         [ 2,  3]],

        [[ 4,  5],
         [ 6,  7]],

        [[ 8,  9],
         [10, 11]],

        [[12, 13],
         [14, 15]],

        [[16, 17],
         [18, 19]]],


       [[[20, 21],
         [22, 23]],

        [[24, 25],
         [26, 27]],

        [[28, 29],
         [30, 31]],

        [[32, 33],
         [34, 35]],

        [[36, 37],
         [38, 39]]],


       [[[40, 41],
         [42, 43]],

        [[44, 45],
         [46, 47]],

        [[48, 49],
         [50, 51]],

        [[52, 53],
         [54, 55]],

        [[56, 57],
         [58, 59]]]])

In [None]:
# first image rgb
array([[[[ 0,  1],
         [ 2,  3]],

        [[ 4,  5],
         [ 6,  7]],

        [[ 8,  9],
         [10, 11]]],