## Load 3D CNN Model

In [1]:
import torch
from torch import nn
import torchvision

#This downloads the model from torchvision
model = torchvision.models.video.r3d_18(pretrained=True, progress=True)

print ('Total number of parameters: {}'.format(sum(p.numel() for p in model.parameters())))

Total number of parameters: 33371472


## Modify the Last Layer

In [2]:
import torch.optim as optim
model.fc = nn.Linear(in_features=512, out_features=10, bias=True)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

## Set the data

In [3]:
# Generating random data for testing purposes
torch.manual_seed(7)
trainData = torch.rand(10, 3, 2, 4, 4)*255
trainLabels = torch.randint(0, 10, (10,))

testData = torch.rand(10, 3, 2, 4, 4)*255
testLabels = torch.randint(0, 10, (10,))

# trainset_loader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=1)
# valset_loader = DataLoader(valset, batch_size=64, shuffle=True, num_workers=1)
# testset_loader = DataLoader(testset, batch_size=64, shuffle=True, num_workers=1)

In [4]:
import torch.nn.functional as F

def train_model(x, y, model, epochs=5):
    model.train()
    for t in range(epochs):
        output = model(x)

        loss = F.nll_loss(output, y)

        # clear the gradients of all tensors being optimized.
        optimizer.zero_grad()
        
        loss.backward()
        
        # parameter update
        optimizer.step()

def check_accuracy(x, y, model):
    num_correct = 0
    num_samples = 0
    test_loss = 0
    model.eval()  # set model to evaluation mode
    with torch.no_grad():
        scores = model(x)
        test_loss += F.nll_loss(scores, y, size_average=False).item() # sum up batch loss
        _, preds = scores.max(1)
        num_correct += (preds == y).sum()
        num_samples += preds.size(0)
        acc = float(num_correct) / num_samples
        print('Got %d / %d correct (%.2f), Loss: (%.2f)' % 
              (num_correct, num_samples, 100 * acc, float(test_loss) / num_samples))

## Train (Finetune) the model

In [5]:
train_model(trainData, trainLabels, model)

## Test the model

In [6]:
check_accuracy(testData, testLabels, model)

Got 1 / 10 correct (10.00), Loss: (-0.04)


