In [None]:
from helperFunctions import VectorQuantizeImage, VecQVAE, FrameDataset
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
from PIL import Image
from torch.utils.data import DataLoader, Dataset
import numpy as np
from einops import rearrange
import urllib
import io
from PIL import Image, ImageSequence
import pandas as pd
import torch.nn as nn
import os
from transformers import BertTokenizer, BertModel
from collections import OrderedDict
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR
import wandb


In [None]:
wandb.login()

wandb.init(
    project="T2V-Decoder",  
    name="experiment-1-thread-1",    
    # id="m6ms1f4w",  
    # resume="allow",
)

In [2]:
dataset = pd.read_csv("/Users/ishananand/Desktop/Text-To-Video-Generation/data/modified_tgif.csv")
dataset = dataset[(dataset['frames'] <= 40) & (dataset['frames'] > 15)].copy().reset_index(drop=True)
dataset = dataset[:10000] 

dataset.shape

(10000, 3)

In [3]:
CACHEDIR = "/Users/ishananand/Desktop/Text-To-Video-Generation/data/cacheGIF"

class FrameDataset(Dataset):
    def __init__(self, data, totalSequence=40, transform=None, cache_dir=CACHEDIR):
        super().__init__()
        self.data = data
        self.transform = transform
        self.totalSequence = totalSequence
        self.cache_dir = cache_dir
        os.makedirs(self.cache_dir, exist_ok=True)

    def __len__(self):
        return len(self.data)

    def npArray(self, index):
        row = self.data.iloc[index]
        url = row['url']
        resp = urllib.request.urlopen(url)
        image_data = resp.read()
        img = Image.open(io.BytesIO(image_data))

        frames = []
        for frame in ImageSequence.Iterator(img):
            frame_rgb = frame.convert("RGB")
            frames.append(np.array(frame_rgb))
        return frames

    def __getitem__(self, index):
        row = self.data.iloc[index]
        url = row['url']
        caption = row['caption']
        totalframes = row['frames']
        gif_path = os.path.join(self.cache_dir, f'{index}.gif')

        if not os.path.exists(gif_path):
            resp = urllib.request.urlopen(url)
            image_data = resp.read()
            with open(gif_path, 'wb') as f:
                f.write(image_data)
        else:
            with open(gif_path, 'rb') as f:
                image_data = f.read()

        img = Image.open(io.BytesIO(image_data))

        frames = []
        for frame in ImageSequence.Iterator(img):
            frame_rgb = frame.convert("RGB")
            frames.append(np.array(frame_rgb))

        if len(frames) < self.totalSequence:
            frames += [frames[-1]] * (self.totalSequence - len(frames))
        else:
            frames = frames[:self.totalSequence]

        if self.transform:
            tensorFrames = torch.stack([
                self.transform(Image.fromarray(frame)) for frame in frames
            ])
            tensorFrames = tensorFrames / 255.0
            return tensorFrames, caption
        else:
            return frames, caption

 

tranform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

fdata = FrameDataset(dataset, transform=tranform)

X, Y = fdata.__getitem__(1000)
print(X.shape, Y)


torch.Size([40, 3, 128, 128]) a man smiles and nods his head.


In [6]:
codeBookdim = 1024
embedDim = 256
hiddenDim = 512
inChannels = 3
tranform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
modelA = VecQVAE(inChannels = inChannels, hiddenDim = hiddenDim, codeBookdim = codeBookdim, embedDim = embedDim).to(device)

modelValA = torch.load("/Users/ishananand/Desktop/Text-To-Video-Generation/models/VQVAE-GIF-thread-52.pt", map_location=torch.device('cpu'))
epochs = 1000
modelA.load_state_dict(modelValA['model_state_dict'])

test = torch.randn(1, 40, 3, 128, 128)
quantized_latents, decoderOut, codebook_loss, commitment_loss, encoding_indices, perplexity, diversity_loss = modelA(test)
quantized_latents.shape, decoderOut.shape, codebook_loss, commitment_loss, encoding_indices.shape, perplexity, diversity_loss

(torch.Size([40, 256, 32, 32]),
 torch.Size([40, 3, 128, 128]),
 tensor(0.0217, grad_fn=<MseLossBackward0>),
 tensor(0.0217, grad_fn=<MseLossBackward0>),
 torch.Size([40960]),
 tensor(546.5923),
 tensor(0.0906))

In [7]:
class Text2Video(nn.Module):
    def __init__(self, embedDimension, sequenceLength, codeBookDim, hiddenLayers, heads, feedForwardDim, text_max_length=128, drop=0.15):
        super().__init__()
        self.max_length = text_max_length
        self.embedDimension = embedDimension
        self.codeBookDim = codeBookDim
        self.sequenceLength = sequenceLength

        self.berTokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
        self.bertModel = BertModel.from_pretrained("bert-base-uncased")
        
        for param in self.bertModel.parameters():
            param.requires_grad = False
            
        self.hiddenSize = self.bertModel.config.hidden_size
        
        self.textProjection = nn.Linear(self.hiddenSize, embedDimension)
        self.positionalEmbedding = nn.Embedding(self.max_length, embedDimension)
        self.temporalPositionalEmbedding = nn.Embedding(self.sequenceLength, embedDimension)

        self.textMultiAttention = nn.MultiheadAttention(embedDimension, heads, dropout=drop, batch_first=True)
        self.textlayerNorm = nn.LayerNorm(embedDimension)

        self.decoder_layer = nn.TransformerDecoderLayer(
            d_model=embedDimension, 
            nhead=heads, 
            dim_feedforward=feedForwardDim, 
            dropout=drop, 
            batch_first=True,
            activation='gelu'
        )
        self.decoder = nn.TransformerDecoder(decoder_layer=self.decoder_layer, num_layers=hiddenLayers)
        self.decoder_norm = nn.LayerNorm(embedDimension)
        self.predictIndices = nn.Linear(embedDimension, codeBookDim)

    def forward(self, text, device):
        if isinstance(text, str):
            text = [text]
        elif isinstance(text, (list, tuple)):
            pass
        else:
            raise ValueError(f"Give string or list of strings, recieved this {type(text)}")
            
        batchSize = len(text)

        tokens = self.berTokenizer(text, return_tensors='pt', padding='max_length', 
                                  truncation=True, max_length=self.max_length).to(device)
        with torch.no_grad():
            outputs = self.bertModel(**tokens)
            lastLayerEMbeddings = outputs.last_hidden_state
        
        positions = torch.arange(0, self.max_length, device=lastLayerEMbeddings.device).unsqueeze(0).expand(batchSize, -1)
        positionalEmbeddings = self.positionalEmbedding(positions)
        
        textEmbeddings = self.textProjection(lastLayerEMbeddings)
        textEmbeddings = textEmbeddings + positionalEmbeddings
        textEmbeddings = self.textlayerNorm(textEmbeddings)
        
        temporalPositions = torch.arange(0, self.sequenceLength, device=device).unsqueeze(0).expand(batchSize, -1)
        temporal_queries = self.temporalPositionalEmbedding(temporalPositions)
        
        frame_text_features, _ = self.textMultiAttention(
            query=temporal_queries,
            key=textEmbeddings,
            value=textEmbeddings
        )
        
        causal_mask = torch.triu(torch.ones(self.sequenceLength, self.sequenceLength, device=device), diagonal=1).bool()
        
        decoderOut = self.decoder(
            tgt=frame_text_features, 
            memory=textEmbeddings, 
            tgt_mask=causal_mask
        )
        decoderOut = self.decoder_norm(decoderOut)
        
        encoding_indices = self.predictIndices(decoderOut)
        return encoding_indices
    
t2v = Text2Video(embedDimension=256, sequenceLength=40, codeBookDim=1024, hiddenLayers=6, heads=8, feedForwardDim=2048, drop=0.15)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
t2v.to(device)

text = ["a cat jumping on a bed", "A man Walking", "He is Running"]
logits = t2v(text, device)
logits.shape

torch.Size([3, 40, 1024])

In [None]:
BATCH_SIZE = 2
embedDimension = 256
sequenceLength = 40
codeBookDim = 1024
hiddenLayers=6
heads=8
feedForwardDim=2048
drop=0.15
learning_rate = 3e-4
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tranform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])


torchDataset = FrameDataset(dataset, totalSequence=sequenceLength, transform=tranform)
dataloader = DataLoader(torchDataset, batch_size=BATCH_SIZE, shuffle = True, num_workers=8, persistent_workers=True)
model = Text2Video(embedDimension=embedDimension, sequenceLength=sequenceLength, codeBookDim=codeBookDim, hiddenLayers=hiddenLayers, heads=heads, feedForwardDim=feedForwardDim, drop=drop)
model = torch.nn.DataParallel(model)
model.to(device)

lossFn =  nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)#, weight_decay=1e-5)
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)


epochs = 1000

In [74]:
start_epoch = 0

checkpoint_path = "/Users/ishananand/Desktop/Text-To-Video-Generation/videoModels/t2VModel.pt"
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    modelA.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    print(f"Resuming from epoch {start_epoch}")
else:
    print("Loading pretrained model...")
    # modelValA = torch.load("./projects/t2v-gif/models/t2VModel.pt", map_location=torch.device('cpu'))
    # modelA.load_state_dict(modelValA)

modelA = torch.nn.DataParallel(modelA)
modelA.to(device)

for each_epoch in range(start_epoch, epochs):
    modelA.train()
    decoderLoss = 0.0
    
    loop = tqdm(dataloader, f"{each_epoch}/{epochs}")

    for X, Y in loop:
        # print(X.shape, Y)
        with torch.no_grad():
            _, _, _, _, encoding_indices, _, _ = modelA(X)
        
        y_pred = model(Y, device)
        break
        print(y_pred.shape, encoding_indices.shape)
        y_pred_reshaed = rearrange(y_pred, 'b t d -> (b t) d')
        encoding_indices_flat = rearrange(y_pred, 'b t d -> (b t) d', b = BATCH_SIZE, t = sequenceLength, d = codeBookdim)
        loss = lossFn(y_pred_reshaed, encoding_indices_flat)
        lossVal += loss.item()
   
        
        
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        loop.set_postfix({
            "TotalL": f"{decoderLoss}"
        })

    decoderLoss /= len(dataloader)   
    
    torch.save({
        'epoch': each_epoch,
        'model_state_dict': modelA.module.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict()
    }, checkpoint_path)
    break
    
    # wandb.log({
    #     "Learning Rate": optimizer.param_groups[0]['lr'],
    #     "Decoder Loss": decoderLoss
    # })
    scheduler.step()
 

Loading pretrained model...


0/1000:   0%|          | 0/5000 [02:21<?, ?it/s]


NameError: name 'optimizer' is not defined

In [None]:
lossVal