## Imports

In [1]:
import os
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import wandb
import numpy as np
from torchvision import transforms
from video_dataset import VideoFrameDataset, ImglistToTensor
from mmcv_resnet_slow import ResNet3dSlowOnly
from mmcv_tpn import TPN
from tpn_head import TPNHead

wandb.init(entity="cares", project="autoencoder-experiments", group="tpn-cls", name="wlasl")

[34m[1mwandb[0m: Currently logged in as: [33msttaseen[0m ([33mcares[0m). Use [1m`wandb login --relogin`[0m to force relogin


## Device Agnostic Code

In [2]:
try:
    device = 'mps' if torch.backends.mps.is_available() else 'cpu'
except:
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Data Setup

In [3]:
data_root = os.path.join(os.getcwd(), 'data/wlasl/rawframes') 
ann_file_train = os.path.join(os.getcwd(), 'data/wlasl/train_annotations.txt') 
ann_file_test = os.path.join(os.getcwd(), 'data/wlasl/test_annotations.txt')
work_dir = 'work_dirs/wlasl/tpncls/'
batch_size = 16



os.makedirs(work_dir, exist_ok=True)


# Setting up data augments
train_pipeline = transforms.Compose([
        ImglistToTensor(), # list of PIL images to (FRAMES x CHANNELS x HEIGHT x WIDTH) tensor
        transforms.Resize((256, 256)), # image batch, resize smaller edge to 256
        transforms.RandomResizedCrop((248, 248), scale=(0.5, 1.0)), # image batch, center crop to square 224x224
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
    ])

test_pipeline = transforms.Compose([
        ImglistToTensor(), # list of PIL images to (FRAMES x CHANNELS x HEIGHT x WIDTH) tensor
        transforms.Resize((256, 256)),  # image batch, resize smaller edge to 256
        transforms.CenterCrop((224, 224)),  # image batch, center crop to square 224x224
        transforms.Normalize(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375]),
    ])

# Setting up datasets
train_dataset = VideoFrameDataset(
    root_path=data_root,
    annotationfile_path=ann_file_train,
    num_segments=12,
    frames_per_segment=1,
    imagefile_template='img_{:05d}.jpg',
    transform=train_pipeline,
    test_mode=False
)


test_dataset = VideoFrameDataset(
    root_path=data_root,
    annotationfile_path=ann_file_test,
    num_segments=12,
    frames_per_segment=1,
    imagefile_template='img_{:05d}.jpg',
    transform=test_pipeline,
    test_mode=True
)

# Setting up dataloaders
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                    batch_size=batch_size,
                                    shuffle=True,
                                    num_workers=4,
                                    pin_memory=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                    batch_size=1,
                                    shuffle=True,
                                    num_workers=4,
                                    pin_memory=True)

# # Testing
# dataiter = iter(test_loader)
# get = next(dataiter)
# reshape = get[0].permute(0,2,1,3,4)
# video = iter(reshape[0][0])
# plt.imshow(next(video))

In [4]:
# plt.imshow(next(video))

## Set up TPN

In [5]:
class TPNEncoder(nn.Module):
    def __init__(self, backbone, neck):
        super(TPNEncoder, self).__init__()
        self.backbone = backbone
        self.neck = neck
        self.loss_aux = {}
        
    def get_loss_aux(self):
        return self.loss_aux['loss_aux']
        
    def forward(self, x, targets=None):
        code = self.backbone(x)
        code, loss_aux = self.neck(code, targets)
        self.loss_aux = loss_aux
        return code

## Set up Autoencoder

In [6]:
class EncoderDecoder(nn.Module):
    def __init__(self, encoder, decoder):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        
    def forward(self, x, targets):
        code = self.encoder(x, targets)
        return self.decoder(code)

## Initialise model

In [7]:
# Create a TPN model
backbone = ResNet3dSlowOnly(
            depth=50,
            pretrained='torchvision://resnet50',
            lateral=False,
            out_indices=(2, 3),
            conv1_kernel=(1, 7, 7),
            conv1_stride_t=1,
            pool1_stride_t=1,
            inflate=(0, 0, 1, 1),
            norm_eval=False)

backbone.init_weights()

neck = TPN(
        in_channels=(1024, 2048),
        out_channels=1024,
        spatial_modulation_cfg=dict(
            in_channels=(1024, 2048), out_channels=2048),
        temporal_modulation_cfg=dict(downsample_scales=(8, 8)),
        upsample_cfg=dict(scale_factor=(1, 1, 1)),
        downsample_cfg=dict(downsample_scale=(1, 1, 1)),
        level_fusion_cfg=dict(
            in_channels=(1024, 1024),
            mid_channels=(1024, 1024),
            out_channels=2048,
            downsample_scales=((1, 1, 1), (1, 1, 1))),
        aux_head_cfg=dict(out_channels=400, loss_weight=0.5))

neck.init_weights()

encoder = TPNEncoder(backbone, neck)

decoder = TPNHead(
        num_classes=400,
        in_channels=2048,
        spatial_type='avg',
        consensus=dict(type='AvgConsensus', dim=1),
        dropout_ratio=0.5,
        init_std=0.01)

decoder.init_weights()

model = EncoderDecoder(encoder, decoder)

2023-01-30 12:08:32,750 - mmcv_resnet_slow - INFO - load model from: torchvision://resnet50
2023-01-30 12:08:32,815 - mmcv_resnet_slow - INFO - These parameters in the 2d checkpoint are not loaded: {'fc.bias', 'fc.weight'}


load checkpoint from torchvision path: torchvision://resnet50


## Set up loss. optimiser and scheduler

In [8]:
# Specify loss function
loss_fn = nn.CrossEntropyLoss()

# Specify loss function
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0001, nesterov=True)

# Specify learning rate scheduler
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[75,125])

# Setup wandb
wandb.watch(model, log_freq=10)

[]

In [9]:
# checkpoints = torch.load(work_dir+'latest.pth')
# model.load_state_dict(checkpoints)

## Testing model output

In [10]:
# reshape.shape

In [11]:
# code = encoder(reshape)

In [12]:
# encoder.loss_aux

In [13]:
# for i, (images, targets) in enumerate(train_loader):
#     break

In [14]:
# images, targets = images.permute(0,2,1,3,4), targets

In [15]:
# code = encoder(images, targets)

In [16]:
# encoder.get_loss_aux()

In [17]:
# model.encoder.get_loss_aux()

## Training

In [18]:
def top_k_accuracy(scores, labels, topk=(1, )):
    """Calculate top k accuracy score.
    Args:
        scores (list[np.ndarray]): Prediction scores for each class.
        labels (list[int]): Ground truth labels.
        topk (tuple[int]): K value for top_k_accuracy. Default: (1, ).
    Returns:
        list[float]: Top k accuracy score for each k.
    """
    res = np.zeros(len(topk))
    labels = np.array(labels)[:, np.newaxis]
    for i, k in enumerate(topk):
        max_k_preds = np.argsort(scores, axis=1)[:, -k:][:, ::-1]
        match_array = np.logical_or.reduce(max_k_preds == labels, axis=1)
        topk_acc_score = match_array.sum() / match_array.shape[0]
        res[i] = topk_acc_score

    return res

In [19]:
def train_one_epoch(epoch_index, interval=5):
    running_loss = 0.
    last_loss = 0.

    # Here, we use enumerate(training_loader) instead of
    # iter(training_loader) so that we can track the batch
    # index and do some intra-epoch reporting
    for i, (images, targets) in enumerate(train_loader):
        # Every data instance is an input + label pair
        images, targets = images.to(device).permute(0,2,1,3,4), targets.to(device)

        # Zero your gradients for every batch!
        optimizer.zero_grad()

        # Make predictions for this batch
        outputs = model(images, targets)

        # Compute the total loss with the encoder's auxiliary loss
        loss = loss_fn(outputs, targets) + model.encoder.get_loss_aux()
        loss.backward()
        
        # Gradient Clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=40, norm_type=2)

        # Adjust learning weights
        optimizer.step()
        scheduler.step()

        # Gather data and report
        running_loss += loss.item()
        if i % interval == interval-1:
            last_loss = running_loss / interval # loss per batch     
            print(f'Epoch [{epoch_index}][{i+1}/{len(train_loader)}], lr: {scheduler.get_last_lr()[0]:.5e}, loss: {last_loss:.5}')
            running_loss = 0.
    
    avg_loss = running_loss / (i + 1)
    return avg_loss

In [20]:
def validate():
    running_vloss = 0.0
    running_vacc = np.zeros(2)
    
    print('Evaluating top_k_accuracy...')
    
    with torch.inference_mode():   
        for i, (vimages, vtargets) in enumerate(test_loader):
            vimages, vtargets = vimages.to(device), vtargets.to(device)
            
            voutputs = model(vimages.permute(0,2,1,3,4), vtargets)
            
            vloss = loss_fn(voutputs, vtargets) + model.encoder.get_loss_aux()
            running_vloss += vloss

            running_vacc += top_k_accuracy(voutputs.detach().cpu().numpy(), vtargets.detach().cpu().numpy(), topk=(1,5))

    avg_vloss = running_vloss / (i + 1)

    acc = running_vacc/len(test_loader)
    top1_acc = acc[0].item()
    top5_acc = acc[1].item()
    
    return (avg_vloss, top1_acc, top5_acc)

In [21]:
epochs = 150

best_vloss = 1_000_000.

# Transfer model to device
model.to(device)

for epoch in range(epochs):
    
    # Turn on gradient tracking and do a forward pass
    model.train(True)
    avg_loss = train_one_epoch(epoch+1)
    
    # Turn off  gradients for reporting
    model.train(False)
    
    avg_vloss, top1_acc, top5_acc = validate()
    
    print(f'top1_acc: {top1_acc:.4}, top5_acc: {top5_acc:.4}, train_loss: {avg_loss:.5}, val_loss: {avg_vloss:.5}')
    
    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = work_dir + f'epoch_{epoch+1}'
        print(f'Saving checkpoint at {epoch+1} epochs...')
        torch.save(model.state_dict(), model_path)
        
    # Track wandb
    wandb.log({'loss': avg_loss,
             'val/loss': avg_vloss,
             'val/top1_accuracy': top1_acc,
             'val/top5_accuracy': top5_acc})



Epoch [1][5/112], lr: 1.00000e-02, loss: 8.9355
Epoch [1][10/112], lr: 1.00000e-02, loss: 8.5296
Epoch [1][15/112], lr: 1.00000e-02, loss: 8.4701
Epoch [1][20/112], lr: 1.00000e-02, loss: 7.8916
Epoch [1][25/112], lr: 1.00000e-02, loss: 7.935
Epoch [1][30/112], lr: 1.00000e-02, loss: 7.6815
Epoch [1][35/112], lr: 1.00000e-02, loss: 7.8408
Epoch [1][40/112], lr: 1.00000e-02, loss: 7.5524
Epoch [1][45/112], lr: 1.00000e-02, loss: 7.4407
Epoch [1][50/112], lr: 1.00000e-02, loss: 7.5729
Epoch [1][55/112], lr: 1.00000e-02, loss: 7.505
Epoch [1][60/112], lr: 1.00000e-02, loss: 7.5715
Epoch [1][65/112], lr: 1.00000e-02, loss: 7.3453
Epoch [1][70/112], lr: 1.00000e-02, loss: 7.5266
Epoch [1][75/112], lr: 1.00000e-03, loss: 7.3394
Epoch [1][80/112], lr: 1.00000e-03, loss: 7.3859
Epoch [1][85/112], lr: 1.00000e-03, loss: 7.2694
Epoch [1][90/112], lr: 1.00000e-03, loss: 7.2463
Epoch [1][95/112], lr: 1.00000e-03, loss: 7.2638
Epoch [1][100/112], lr: 1.00000e-03, loss: 7.2421
Epoch [1][105/112], lr

KeyboardInterrupt: 