# CSN-Pytorch

## Imports

In [13]:
import os
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
# import wandb
from csn import csn50
from torchvision import transforms
from video_dataset import VideoFrameDataset, ImglistToTensor

# wandb.init(entity="cares", project="autoencoder-experiments", group="pytorch-csn")

## Device Agnostic Code

In [14]:
try:
    device = 'mps' if torch.backends.mps.is_available() else 'cpu'
except:
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Data Setup

In [3]:
data_root = os.path.join(os.getcwd(), 'data/autsl/rawframes') 
ann_file_train = os.path.join(os.getcwd(), 'data/autsl/train_annotations.txt') 
ann_file_test = os.path.join(os.getcwd(), 'data/autsl/test_annotations.txt')
batch_size = 20

# Setting up data augments
train_pipeline = transforms.Compose([
        ImglistToTensor(), # list of PIL images to (FRAMES x CHANNELS x HEIGHT x WIDTH) tensor
        transforms.Resize(256), # image batch, resize smaller edge to 256
        transforms.RandomResizedCrop((224, 224)), # image batch, center crop to square 224x224
        transforms.RandomHorizontalFlip(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

test_pipeline = transforms.Compose([
        ImglistToTensor(), # list of PIL images to (FRAMES x CHANNELS x HEIGHT x WIDTH) tensor
        transforms.Resize(256),  # image batch, resize smaller edge to 256
        transforms.CenterCrop((224, 224)),  # image batch, center crop to square 224x224
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

# Setting up datasets
train_dataset = VideoFrameDataset(
    root_path=data_root,
    annotationfile_path=ann_file_train,
    num_segments=5,
    frames_per_segment=1,
    imagefile_template='img_{:05d}.jpg',
    transform=train_pipeline,
    test_mode=False
)


test_dataset = VideoFrameDataset(
    root_path=data_root,
    annotationfile_path=ann_file_test,
    num_segments=5,
    frames_per_segment=1,
    imagefile_template='img_{:05d}.jpg',
    transform=test_pipeline,
    test_mode=True
)

# Setting up dataloaders
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                    batch_size=batch_size,
                                    shuffle=False,
                                    num_workers=4,
                                    pin_memory=False)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                    batch_size=2,
                                    shuffle=False,
                                    num_workers=4,
                                    pin_memory=False)

In [4]:
dataiter = iter(train_loader)
get = next(dataiter)
get[0].shape

torch.Size([20, 5, 3, 224, 224])

In [5]:
get = next(dataiter)
get[0].shape

torch.Size([20, 5, 3, 224, 224])

In [6]:
# Testing dataloader
dataiter = iter(train_loader)
get = next(dataiter)
get[0].size(0)

20

In [7]:
reshape = get[0].permute(0,2,1,3,4)
reshape.shape

torch.Size([20, 3, 5, 224, 224])

In [9]:
values = csn(reshape)

hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
hello
layer4: torch.Size([20, 2048, 1, 7, 7])
avg_pool: torch.Size([20, 2048, 1, 1, 1])
view: torch.Size([20, 2048])
fc: torch.Size([20, 400])


In [16]:
csn.state_dict

<bound method Module.state_dict of CSN(
  (conv1): Conv3d(3, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2), padding=(1, 3, 3), bias=False)
  (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (max_pool): MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1), dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): CSNBottleneck(
      (conv1): Conv3d(64, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
      (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Sequential(
        (0): Conv3d(64, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1), bias=False)
        (1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), groups=64, bias=False)
      )
      (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv3d(64, 256, kernel_size=(1, 1, 1), s

In [17]:
torch.save(csn.state_dict(), 'csn_pytorch.pth')

## Set up model, loss, optimiser and others

In [15]:
# Create a CSN model
csn = csn50(num_classes=400)

# Specify loss function
loss_fn = nn.CrossEntropyLoss()

# Specify loss function
optimizer = torch.optim.SGD(csn.parameters(), lr=0.000125, momentum=0.9, weight_decay=0.0001)

# Specify learning rate scheduler
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[12,16])

# Setup wandb
# wandb.watch(csn, log_freq=100)

## Train model

In [49]:
epochs = 1
interval = 1 # For checkpoints and validation

losses = []
eval_accu = []
eval_losses=[]

csn.to(device)
csn.train()

for epoch in range(epochs):
    # Reset train loss
    train_loss = 0.0
    for video_batch, targets in train_loader:
        # Move data to device
        video_batch, targets = video_batch.to(device), targets.to(device)
        
        # batch_size, n_frames, channels, h, w -> previous doesn't work
        # batch_size, channels, n_frames, h, w -> current
        predictions = csn(video_batch.permute(0,2,1,3,4))
        
        # Calculate loss
        loss = loss_fn(predictions, targets)
        
        # Backpropagate
        optimizer.zero_grad()
        loss.backward()
        
        optimizer.step()
        scheduler.step()
        
        # Calculate batch loss
        train_loss += loss.item()

    
    # Check for interval to validate and save checkpoints
    if interval!=0 and epoch%interval==0:
        running_loss=0
        correct=0
        total=0
        with torch.no_grad():
            for video_batch, targets in test_loader:
                video_batch, targets = video_batch.to(device), targets.to(device)

                predictions = csn(video_batch.permute(0,2,1,3,4))
                x = predictions
                
                loss = loss_fn(predictions, targets)
                running_loss += loss.item()

                _, predicted = predictions.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        test_loss=running_loss/len(test_loader)
        accu=100.*correct/total

        eval_losses.append(test_loss)
        eval_accu.append(accu)
        print('Test Loss: %.3f | Accuracy: %.3f'%(test_loss,accu)) 

    train_loss = train_loss/len(train_loader)
    losses.append(train_loss)
    print(f'Epoch: {epoch+1} \tTraining Loss: {train_loss:.6f}')
    
    # Wandb Log
#     wandb.log({"loss": loss,
#               "val/top_1_acc": accu})

KeyboardInterrupt: 

In [None]:
plt.plot(losses)
plt.title('Loss Graph')
plt.ylabel('Loss')
plt.xlabel('Epoch')