# Re-implementing ViTTA (light version)

We will implement ViTTA in PyTorch and evaluate it on two model architectures: TANet, based on ResNet50, and Video Swin Transformer, based on Swin-B.

## Pipeline

<img src="https://wlin-at.github.io/media/vitta/pipeline.png">

You can refer to the [ViTTA paper](https://openaccess.thecvf.com/content/CVPR2023/papers/Lin_Video_Test-Time_Adaptation_for_Action_Recognition_CVPR_2023_paper.pdf) for more information.

In [6]:
import torchvision
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import UCF101
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
import cv2
import av

## Data Loading

Let's first load the UCF101 dataset and try to visualize it using matplotlib and OpenCV:

In [None]:
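# UCF101 passes each clip to `transform` as a uint8 tensor of shape (T, H, W, C).
# ToTensor only accepts PIL images or ndarrays, so scale and reorder by hand instead.
transform = Compose([
    # (T, H, W, C) uint8 -> (T, C, H, W) float in [0, 1]
    torchvision.transforms.Lambda(lambda v: v.permute(0, 3, 1, 2).float() / 255.0),
    Resize((112, 112)),
    CenterCrop((112, 112)),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    # Back to (T, H, W, C), the layout VideoDataset's permute(3, 0, 1, 2) expects.
    torchvision.transforms.Lambda(lambda v: v.permute(0, 2, 3, 1)),
])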

data_path = 'data/UCF101/'

ucf101_train = UCF101(
    root=data_path + 'UCF101',
    annotation_path=data_path + 'ucfTrainTestlist',
    frames_per_clip=5,
    step_between_clips=1,
    train=True,
    transform=transform)

ucf101_val = UCF101(
    root=data_path + 'UCF101',
    annotation_path=data_path + 'ucfTrainTestlist',
    frames_per_clip=5,
    step_between_clips=1,
    train=False,
    transform=transform)

In [None]:
ucf101_train
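
To get a feel for how many samples the dataset above exposes, the following sketch counts the clips a sliding window produces from a single video. It assumes clips are consecutive windows of `frames_per_clip` frames whose start positions advance by `step_between_clips`, with no frame-rate resampling. `clip_starts` and `count_clips` are hypothetical helpers, not part of torchvision, and the 100-frame video is an illustrative number.

```python
def clip_starts(num_frames, frames_per_clip, step_between_clips):
    """Start indices of sliding-window clips over a video of num_frames frames."""
    if num_frames < frames_per_clip:
        return []  # video too short to yield a single clip
    last_start = num_frames - frames_per_clip
    return list(range(0, last_start + 1, step_between_clips))


def count_clips(num_frames, frames_per_clip, step_between_clips):
    """Number of clips produced under the sliding-window assumption."""
    return len(clip_starts(num_frames, frames_per_clip, step_between_clips))


# Illustrative 100-frame video with the clip length used above.
print(count_clips(100, frames_per_clip=5, step_between_clips=1))  # 96
print(count_clips(100, frames_per_clip=5, step_between_clips=5))  # 20
```

With a step of 1, neighbouring clips share four of their five frames, so the dataset is large and highly redundant; a larger step yields fewer, less overlapping samples.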

Let's wrap it in a `Dataset` that drops the audio track and returns each clip in channels-first (C, T, H, W) layout:

In [13]:
class VideoDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)
    
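    def __getitem__(self, idx):
        # UCF101 yields (video, audio, label); the audio track is not needed here.
        video, _, label = self.dataset[idx]
        # Reorder (T, H, W, C) -> (C, T, H, W), i.e. channels first.
        video = video.permute(3,0,1,2)
        return video, label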
    
ucf101_train_dataset = VideoDataset(ucf101_train)
ucf101_val_dataset = VideoDataset(ucf101_val)
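
The `permute(3, 0, 1, 2)` call above reorders axes rather than data values. As a minimal, framework-free sketch, the hypothetical helper `permuted_shape` below shows how the shape changes, assuming each clip arrives as (T, H, W, C) with 5 frames of 112x112 RGB.

```python
def permuted_shape(shape, dims):
    """Shape after reordering axes so that new axis i is old axis dims[i]."""
    if sorted(dims) != list(range(len(shape))):
        raise ValueError("dims must be a permutation of the axis indices")
    return tuple(shape[d] for d in dims)


# A clip of 5 RGB frames at 112x112, stored frames-first and channels-last.
thwc = (5, 112, 112, 3)
cthw = permuted_shape(thwc, (3, 0, 1, 2))
print(cthw)  # (3, 5, 112, 112)
```

Channels-first (C, T, H, W) is the layout that 3D-convolutional video backbones typically expect once a batch dimension is prepended.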

Now let's create the DataLoaders:

In [None]:
train_loader = DataLoader(
    ucf101_train_dataset,
    batch_size=1,
    shuffle=True
)

val_loader = DataLoader(
    ucf101_val_dataset,
    batch_size=8,
    shuffle=True
)
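
As a quick sanity check on the loaders above, this sketch computes how many batches a loader yields and the shape of a full batch. It assumes the default `drop_last=False` behaviour and the (C, T, H, W) clip layout. `num_batches` and `batch_shape` are hypothetical helpers, and the sample counts are illustrative, not the real UCF101 clip counts.

```python
import math


def num_batches(num_samples, batch_size, drop_last=False):
    """Batches per epoch; the last one may be partial unless drop_last is set."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


def batch_shape(batch_size, sample_shape):
    """Shape of a full batch: the batch dimension is prepended to the sample shape."""
    return (batch_size, *sample_shape)


print(num_batches(1000, 8))  # 125
print(num_batches(1001, 8))  # 126
print(batch_shape(8, (3, 5, 112, 112)))  # (8, 3, 5, 112, 112)
```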