In [1]:
import numpy as np
import torch
import torch.nn as nn
import pandas as pd
import os
from torch.utils.data import Dataset
from PIL import Image
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
import clip

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
IMG_PATH='frames/2014-12-10 - 22-45 Barcelona 3 - 1 Paris SG/'
CAP_PATH='captions/Labels-caption.csv'


In [3]:
df=pd.read_csv("captions/Labels-caption.csv")
df_cleaned=df[["gameTime","anonymized"]]
def clean_gameTime(time):
    return time.replace(":","_")
df_cleaned['gameTime']=df_cleaned['gameTime'].apply(clean_gameTime)
df_cleaned=df_cleaned.sort_values(by=['gameTime'])
print(df_cleaned.head())

      gameTime                                         anonymized
123  1 - 10_37  [PLAYER] ([TEAM]) whips the ball in, but it fa...
122  1 - 10_51  [PLAYER] ([TEAM]) sends a long ball in, but [P...
121  1 - 11_24  [PLAYER] ([TEAM]) was too forceful with his ta...
120  1 - 12_36  [TEAM] are playing possession football and con...
119  1 - 13_55  [PLAYER] ([TEAM]) releases [PLAYER], who latch...


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_cleaned['gameTime']=df_cleaned['gameTime'].apply(clean_gameTime)


In [4]:
caption_arr=dict(zip(df_cleaned['gameTime'],df_cleaned['anonymized']))
print(caption_arr)

{'1 - 10_37': "[PLAYER] ([TEAM]) whips the ball in, but it fails to reach any of his teammates as the opposition's defence averts the threat.", '1 - 10_51': '[PLAYER] ([TEAM]) sends a long ball in, but [PLAYER] comfortably gathers the ball.', '1 - 11_24': '[PLAYER] ([TEAM]) was too forceful with his tackle and [REFEREE] interrupted the game to signal a free kick. A free kick to [TEAM].', '1 - 12_36': '[TEAM] are playing possession football and controlling the game at present. They are exchanging inch-perfect passes from player to player, making it difficult for the opposition to win the ball.', '1 - 13_55': "[PLAYER] ([TEAM]) releases [PLAYER], who latched on to his perfect through pass and can continue with his team's attack.", '1 - 14_30': "[PLAYER] ([TEAM]) releases [PLAYER], who latched on to his perfect through pass and can continue with his team's attack.", '1 - 14_50': 'What a goal! [PLAYER] plays it to [PLAYER] ([TEAM]), who finds himself unmarked inside the box and slots a fir

In [5]:
frame_arr={}
for folder_name in os.listdir(IMG_PATH):
    folder_path=os.path.join(IMG_PATH,folder_name)
    for img in os.listdir(folder_path):
        img_path=os.path.join(folder_path,img)
        if(folder_name not in frame_arr):
            frame_arr[folder_name]=[img_path]
        else:
            frame_arr[folder_name].append(img_path)  

In [6]:
print(frame_arr)

{'1 - 10_37': ['frames/2014-12-10 - 22-45 Barcelona 3 - 1 Paris SG/1 - 10_37\\frame_15180.png', 'frames/2014-12-10 - 22-45 Barcelona 3 - 1 Paris SG/1 - 10_37\\frame_15192.png', 'frames/2014-12-10 - 22-45 Barcelona 3 - 1 Paris SG/1 - 10_37\\frame_15204.png', 'frames/2014-12-10 - 22-45 Barcelona 3 - 1 Paris SG/1 - 10_37\\frame_15216.png', 'frames/2014-12-10 - 22-45 Barcelona 3 - 1 Paris SG/1 - 10_37\\frame_15228.png', 'frames/2014-12-10 - 22-45 Barcelona 3 - 1 Paris SG/1 - 10_37\\frame_15240.png', 'frames/2014-12-10 - 22-45 Barcelona 3 - 1 Paris SG/1 - 10_37\\frame_15252.png', 'frames/2014-12-10 - 22-45 Barcelona 3 - 1 Paris SG/1 - 10_37\\frame_15264.png', 'frames/2014-12-10 - 22-45 Barcelona 3 - 1 Paris SG/1 - 10_37\\frame_15276.png', 'frames/2014-12-10 - 22-45 Barcelona 3 - 1 Paris SG/1 - 10_37\\frame_15288.png', 'frames/2014-12-10 - 22-45 Barcelona 3 - 1 Paris SG/1 - 10_37\\frame_15300.png', 'frames/2014-12-10 - 22-45 Barcelona 3 - 1 Paris SG/1 - 10_37\\frame_15312.png', 'frames/2014-

In [19]:
from transformers import ViTImageProcessor
from transformers import CLIPModel,CLIPImageProcessor
from transformers import ViTModel

img_processor=CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model,_=clip.load("ViT-B/32", device='cpu')
vit_model=ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
device="cuda" if torch.cuda.is_available() else 'cpu'
class VideoCaptionDataset(Dataset):
    def __init__(self,img,caption):
        self.img_path=img
        self.caption=caption
        self.ts=list(self.img_path.keys())
        self.transform=transforms.Compose([transforms.Resize((224, 224)),  
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    def __len__(self):
        return len(self.img_path)
    def __getitem__(self,index):
        ts=self.ts[index]
        caption=self.caption[ts]
        caption=caption[:77]
        img_paths=self.img_path[ts][:60]
        images=[]
        for img_path in img_paths:
            img=Image.open(img_path)
            img=self.transform(img)
            # img=vit_model(img)
            images.append(img)
        images=torch.stack(images)
        images=images.mean(dim=0)
        caption=clip.tokenize([caption]).squeeze(0).to(device)
        # images_encoded=clip_model.encode_image(images)
        # cap_encoded=clip_model.encode_text(clip.tokenize([caption]).to('cpu'))
        return images,caption


In [8]:
for data in VideoCaptionDataset(frame_arr,caption_arr):
    i,c=data
    print(i.shape)
    break

torch.Size([60, 512])


In [9]:
class TemporalFusionModule(nn.Module):
    def __init__(self, input_size, num_scales, lstm_hidden_size, num_heads, num_layers, d_ff):
        super(TemporalFusionModule, self).__init__()
        self.num_scales = num_scales
        self.downscales = nn.ModuleList([
            nn.Conv1d(input_size, input_size, kernel_size=3, stride=2, padding=1) for _ in range(num_scales)
        ])
        self.fusion_blocks = nn.ModuleList([
            nn.TransformerEncoder(nn.TransformerEncoderLayer(input_size, num_heads, d_ff), num_layers) for _ in range(num_scales)
        ])
        self.upscales = nn.ModuleList([
            nn.ConvTranspose1d(input_size, input_size, kernel_size=4, stride=2, padding=1) for _ in range(num_scales)
        ])
        self.lstm = nn.LSTM(input_size, lstm_hidden_size, num_layers, batch_first=True)
        self.out = nn.Linear(lstm_hidden_size, input_size)
        
    def forward(self, x):
        # Downscale
        downscale_outputs = []
        for i in range(self.num_scales):
            downscale_outputs.append(self.downscales[i](x))
            x = F.avg_pool1d(x, kernel_size=2, stride=2)
        
        # Temporal fusion
        fusion_outputs = []
        for i in range(self.num_scales):
            fusion_outputs.append(self.fusion_blocks[i](downscale_outputs[i]))
        
        # Upscale
        for i in range(self.num_scales):
            fusion_outputs[i] = self.upscales[i](fusion_outputs[i])
        
        # Aggregate features
        fused_features = torch.cat(fusion_outputs, dim=2)
        
        # LSTM
        lstm_out, _ = self.lstm(fused_features)
        
        # Final output
        output = self.out(lstm_out[:, -1, :])  # taking only the last time step
        
        return output

In [10]:
input_size = 512
num_scales = 3
lstm_hidden_size = 512
num_heads = 8
num_layers = 2
d_ff = 2048

fusion_module = TemporalFusionModule(input_size, num_scales, lstm_hidden_size, num_heads, num_layers, d_ff)

print(fusion_module)



TemporalFusionModule(
  (downscales): ModuleList(
    (0-2): 3 x Conv1d(512, 512, kernel_size=(3,), stride=(2,), padding=(1,))
  )
  (fusion_blocks): ModuleList(
    (0-2): 3 x TransformerEncoder(
      (layers): ModuleList(
        (0-1): 2 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
          )
          (linear1): Linear(in_features=512, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2048, out_features=512, bias=True)
          (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
      )
    )
  )
  (upscales): ModuleList(
    (0-2): 3 x ConvTranspose1d(512, 512, kernel_size=(4,), stride=(2,

In [11]:
# for data in VideoCaptionDataset(frame_arr,caption_arr):
#     i,c=data
#     input_size = 512
#     num_scales = 3
#     lstm_hidden_size = 512
#     num_heads = 8
#     num_layers = 2
#     d_ff = 2048
#     fusion_module = TemporalFusionModule(input_size, num_scales, lstm_hidden_size, num_heads, num_layers, d_ff)
#     fused_visual_features = fusion_module(i.transpose(1, 2) )

#     break


In [12]:
class VideoCaptionModel(nn.Module):
    def __init__(self):
        super(self,VideoCaptionModel)
        
    def __init__(self):
        super(VideoCaptionModel, self).__init__()
        # Load the CLIP model
        self.clip_model, _ = clip.load("ViT-B/32", device='cuda' if torch.cuda.is_available() else 'cpu')

    def forward(self, img, text):
        # Image Encoding
        with torch.no_grad():
            image_features = self.clip_model.encode_image(img)

        img_encoding = image_features

        # Text Encoding
        with torch.no_grad():
            text_features = self.clip_model.encode_text(text)

        text_encoding = text_features

        return img_encoding, text_encoding



In [13]:
# dataset=VideoCaptionDataset(frame_arr,caption_arr)
# dataloader = DataLoader(dataset, batch_size=1)

# for data in dataloader:
#     img,c=data
#     temporal_fusion_module = TemporalFusionModule(input_dim=512, num_scales=3, hidden_dim=256, output_dim=128)
#     output = temporal_fusion_module(img)
#     print(output)
#     break

In [21]:
dataset=VideoCaptionDataset(frame_arr,caption_arr)
dataloader = DataLoader(dataset, batch_size=4)
model=VideoCaptionModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
epochs=10
for epoch in range(epochs):
    loss=0.0
    running_count=0
    for i,data in enumerate(dataloader,0):
        img,c=data
        img, c = img.to(device), c.to(device)

        # print(img.shape)
        # print(c.shape)
    #     break
    # break

        img_features,txt_features=model(img,c)

        cosine_sim=img_features@txt_features.t()

        logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        logit_scale = logit_scale.exp() 

        logits_per_image = logit_scale * cosine_sim
        logits_per_text = logits_per_image.t()
        labels = torch.arange(logits_per_image.shape[0], dtype=torch.long)
        total_loss = (
        torch.nn.functional.cross_entropy(logits_per_image, labels) +
        torch.nn.functional.cross_entropy(logits_per_text, labels)
    ) / 2
        total_loss.backward()
        optimizer.step()
        loss+=total_loss.item()
        running_count+=1
        if(i%5==1):
            print(f"Running loss:{loss/running_count}")
            loss=0.0
            running_count=0


Running loss:25.213332176208496
Running loss:28.740072631835936
Running loss:24.14544258117676
Running loss:32.7457218170166
Running loss:22.065915679931642
Running loss:33.20024108886719
Running loss:25.213332176208496
Running loss:28.740072631835936
Running loss:24.14544258117676
Running loss:32.7457218170166
Running loss:22.065915679931642
Running loss:33.20024108886719
Running loss:25.213332176208496
Running loss:28.740072631835936
Running loss:24.14544258117676
Running loss:32.7457218170166
Running loss:22.065915679931642
Running loss:33.20024108886719
Running loss:25.213332176208496
Running loss:28.740072631835936
Running loss:24.14544258117676
Running loss:32.7457218170166
Running loss:22.065915679931642
Running loss:33.20024108886719
Running loss:25.213332176208496


KeyboardInterrupt: 