In [1]:
import torch
import clip
import pandas as pd
from torchvision.transforms import transforms  as T
from torchvision.transforms._transforms_video import ToTensorVideo
from pytorchvideo.transforms import Normalize, Permute, RandAugment
from torch import Tensor, nn, optim
from torchvision.io import read_video
import warnings
from sklearn.preprocessing import LabelEncoder
import torch.nn.functional as F
from tqdm import tqdm

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/16", device=device)

In [3]:
import torch
import torch.nn as nn

class classifier(nn.Module):
    def __init__(self, input_size, output_size):
        super(classifier, self).__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, x):
        x = torch.mean(x, dim=1)
        x = self.fc(x)
        return x

In [4]:
label_encoder = LabelEncoder()
warnings.filterwarnings("ignore")
def extract_frames(video_tensor, num_frames):
    total_frames = video_tensor.shape[0]
    
    if num_frames == total_frames:
        return video_tensor  # Return all frames if requested number is greater or equal to total frames
    elif num_frames > total_frames:
        last_frame = video_tensor[-1:]  # Select the last frame
        num_repeats = num_frames - total_frames
        repeated_frames = last_frame.repeat(num_repeats, 1, 1, 1)  # Repeat the last frame to match num_frames
        selected_frames = torch.cat([video_tensor, repeated_frames], dim=0)  # Concatenate the repeated frames
        return selected_frames
    else:
        # Calculate indices for equally spaced frames
        indices = torch.linspace(0, total_frames - 1, num_frames).round().long()
        
        # Extract frames based on the calculated indices
        selected_frames = video_tensor[indices]
        
        return selected_frames


class VideoDataset(torch.utils.data.Dataset):
    def __init__(self, dataframe, num_frames, type, transform=None):
        self.transform = transform
        self.num_frames = num_frames
        self.dataframe = dataframe[dataframe['type'] == type].reset_index(drop=True) 
        self.dataframe['index_label'] = label_encoder.fit_transform(self.dataframe['label'])
    def __len__(self):
        return len(self.dataframe)
    def __getitem__(self, idx):
        video, audio, info = read_video(self.dataframe['video_path'][idx], pts_unit="sec")
        video = extract_frames(video, self.num_frames)
        if video.shape[0] < self.num_frames:
            print(video.shape[0])
        if self.transform is not None:
            video = self.transform(video)
        label = self.dataframe['index_label'][idx]
        return video, label

In [5]:
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
video_size = 224
train_transform = T.Compose(
        [
            ToTensorVideo(),  # C, T, H, W
            Permute(dims=[1, 0, 2, 3]),  # T, C, H, W
            RandAugment(magnitude=10, num_layers=2),
            Permute(dims=[1, 0, 2, 3]),  # C, T, H, W
            T.Resize(size=(video_size, video_size)),
            Normalize(mean=imagenet_mean, std=imagenet_std),
        ]
    )

test_transform = T.Compose(
    [
        ToTensorVideo(),
        T.Resize(size=(video_size, video_size)),
        Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
)

In [6]:
hmdb = pd.read_csv("/hahmwj/csv_files/hmdb.csv")
train = VideoDataset(hmdb, 2, 'train', train_transform)
val = VideoDataset(hmdb, 2, 'valid', test_transform)
test = VideoDataset(hmdb, 2, 'test', test_transform)

dataloader_train = torch.utils.data.DataLoader(
    train,
    batch_size=16,
    # sampler=torch.utils.data.DistributedSampler(dataset_train),
    num_workers=0,
    pin_memory=True,
    )
dataloader_val = torch.utils.data.DataLoader(
# torch.utils.data.Subset(dataset_val, range(dist.get_rank(), len(dataset_val), dist.get_world_size())),
val,
batch_size=1,
shuffle=False,
num_workers=0,
pin_memory=True,
)

dataloader_test = torch.utils.data.DataLoader(
# torch.utils.data.Subset(dataset_val, range(dist.get_rank(), len(dataset_val), dist.get_world_size())),
test,
batch_size=1,
shuffle=False,
num_workers=0,
pin_memory=True,
)

In [7]:
model = model.visual
for name, param in model.named_parameters():
    param.requires_grad = False
classifier_model = classifier(512, 51)
optimizer = optim.AdamW(classifier_model.parameters(), lr=1e-4, weight_decay=0.9)

In [8]:
model.cuda()
classifier_model.cuda()
train_losses = []
train_acc1s = []
train_acc5s = []

val_losses = []
val_acc1s = []
val_acc5s = []

for epoch in range(50):
   model.train()
   for data, labels in tqdm(dataloader_train):
      data, labels = data.cuda(), labels.cuda()
      optimizer.zero_grad()
      B, T = data.shape[0], data.shape[2]
      data = data.permute(0, 2, 1, 3, 4).flatten(0, 1) # Permute: (B,T,C,H,W), Flatten: (B*T, C, H, W)
      with torch.cuda.amp.autocast():
         logits = model(data)
         logits = logits.contiguous().view(B, T, logits.shape[1])
         logits = classifier_model(logits)
      loss = F.cross_entropy(logits, labels)
      loss.backward()
      optimizer.step()
      train_losses.append(loss)

      acc1 = (logits.topk(1, dim=1)[1] == labels.view(-1, 1)).sum(dim=-1).float().mean().item() * 100
      acc5 = (logits.topk(5, dim=1)[1] == labels.view(-1, 1)).sum(dim=-1).float().mean().item() * 100
      train_acc1s.append(acc1)
      train_acc5s.append(acc5)
   print(f'Training epoch : {epoch}       Acc1 : {acc1}     Acc5 : {acc5}')

   model.eval()
   for data, labels in tqdm(dataloader_val, leave = False):
      data, labels = data.cuda(), labels.cuda()
      B, T = data.shape[0], data.shape[2]
      data = data.permute(0, 2, 1, 3, 4).flatten(0, 1) # Permute: (B,T,C,H,W), Flatten: (B*T, C, H, W)
      with torch.cuda.amp.autocast():
         logits = model(data)
         logits = logits.contiguous().view(B, T, logits.shape[1])
         logits = classifier_model(logits)
      loss = F.cross_entropy(logits, labels)
      val_losses.append(loss)

      acc1 = (logits.topk(1, dim=1)[1] == labels.view(-1, 1)).sum(dim=-1).float().mean().item() * 100
      acc5 = (logits.topk(5, dim=1)[1] == labels.view(-1, 1)).sum(dim=-1).float().mean().item() * 100
      val_acc1s.append(acc1)
      val_acc5s.append(acc5)
   print(f'Validation epoch : {epoch}     Acc1 : {acc1}     Acc5 : {acc5}')

100%|██████████| 270/270 [03:08<00:00,  1.43it/s]


Training epoch : 0       Acc1 : 11.11111119389534     Acc5 : 44.44444477558136


                                                   

Validation epoch : 0     Acc1 : 0.0     Acc5 : 0.0


100%|██████████| 270/270 [03:05<00:00,  1.46it/s]


Training epoch : 1       Acc1 : 22.22222238779068     Acc5 : 77.77777910232544


                                                   

Validation epoch : 1     Acc1 : 0.0     Acc5 : 0.0


100%|██████████| 270/270 [03:17<00:00,  1.37it/s]


Training epoch : 2       Acc1 : 11.11111119389534     Acc5 : 66.66666865348816


                                                   

Validation epoch : 2     Acc1 : 0.0     Acc5 : 0.0


100%|██████████| 270/270 [03:23<00:00,  1.33it/s]


Training epoch : 3       Acc1 : 33.33333432674408     Acc5 : 88.88888955116272


                                                   

Validation epoch : 3     Acc1 : 0.0     Acc5 : 0.0


100%|██████████| 270/270 [03:15<00:00,  1.38it/s]


Training epoch : 4       Acc1 : 22.22222238779068     Acc5 : 77.77777910232544


                                                   

Validation epoch : 4     Acc1 : 0.0     Acc5 : 0.0


100%|██████████| 270/270 [03:12<00:00,  1.40it/s]


Training epoch : 5       Acc1 : 44.44444477558136     Acc5 : 88.88888955116272


                                                   

Validation epoch : 5     Acc1 : 0.0     Acc5 : 0.0


100%|██████████| 270/270 [03:22<00:00,  1.33it/s]


Training epoch : 6       Acc1 : 44.44444477558136     Acc5 : 88.88888955116272


                                                   

Validation epoch : 6     Acc1 : 0.0     Acc5 : 0.0


  1%|          | 2/270 [00:02<05:44,  1.28s/it]


KeyboardInterrupt: 