In [None]:
import torch
import torch.nn as nn
from einops import rearrange
from tqdm import tqdm
import torch.nn.functional  as Fn
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from io import BytesIO
from IPython.core.display import HTML
from IPython.display import Image as IPyImage, display
from PIL import Image, ImageSequence
import matplotlib.pyplot as plt
import os
from piq import ssim
import pandas as pd
import urllib
from torchvision.models import vgg16
import io
import cv2
import hashlib
import kornia

# import wandb
# wandb.login()

In [28]:
# wandb.init(
#     project="Transformer-Decoder",  
#     name="experiment-2",    
#     # id="uqcub7jq",  
#     # resume="allow",
#     # config={                       
#     #     "epochs": 1000,
#     #     "batch_size": 64,
#     # }
# )

In [3]:
class VectorQuantizeImage(nn.Module):
    def __init__(self, codeBookDim = 64, embeddingDim = 32, decay = 0.99, eps = 1e-5):
        super().__init__()

        self.codeBookDim = codeBookDim
        self.embeddingDim = embeddingDim
        self.decay = decay
        self.eps = eps
        self.dead_codeBook_threshold = codeBookDim * 0.6

        self.codebook = nn.Embedding(codeBookDim, embeddingDim)
        nn.init.xavier_uniform_(self.codebook.weight.data)

        self.register_buffer('ema_Count', torch.zeros(codeBookDim))
        self.register_buffer('ema_Weight', self.codebook.weight.data.clone())

    def forward(self, x):
        x_reshaped = x.view(-1, self.embeddingDim)

        distance = (torch.sum(x_reshaped**2, dim=1, keepdim=True) 
                    + torch.sum(self.codebook.weight**2, dim=1)
                    - 2 * torch.matmul(x_reshaped, self.codebook.weight.t()))
        
        encoding_indices = torch.argmin(distance, dim=1) 
        encodings = Fn.one_hot(encoding_indices, self.codeBookDim).type(x_reshaped.dtype)
        quantized = torch.matmul(encodings, self.codebook.weight)

        if self.training:
            self.ema_Count = self.decay * self.ema_Count + (1 - self.decay) * torch.sum(encodings, 0)
            
            x_reshaped_sum = torch.matmul(encodings.t(), x_reshaped.detach())
            self.ema_Weight = self.decay * self.ema_Weight + (1 - self.decay) * x_reshaped_sum
            
            n = torch.clamp(self.ema_Count, min=self.eps)
            updated_embeddings = self.ema_Weight / n.unsqueeze(1)
            self.codebook.weight.data.copy_(updated_embeddings)

        
        avg_probs = torch.mean(encodings, dim=0)
        log_encoding_sum = -torch.sum(avg_probs * torch.log(avg_probs + 1e-10))
        perplexity = torch.exp(log_encoding_sum)

        entropy = log_encoding_sum
        normalized_entropy = entropy / torch.log(torch.tensor(self.codeBookDim, device=x.device))
        diversity_loss = 1.0 - normalized_entropy

        return quantized, encoding_indices, perplexity, diversity_loss
        
        
vq = VectorQuantizeImage(codeBookDim=64,embeddingDim=32)
rand = torch.randn(1024,32)
vq(rand)

(tensor([[-1.3374e-02, -2.4344e-01,  9.7756e-02,  ...,  9.8517e-02,
           1.0220e-01, -2.1899e-01],
         [-1.2528e-01,  9.2228e-02, -1.5133e-01,  ..., -6.6822e-03,
           2.3273e-01,  1.0878e-01],
         [-2.2598e-01,  2.3129e-01, -1.4342e-01,  ...,  2.2462e-01,
          -2.3867e-01, -1.8427e-01],
         ...,
         [ 1.9732e-01,  1.4822e-01, -7.5449e-02,  ..., -1.0177e-02,
          -1.8945e-01, -3.1792e-02],
         [ 1.2054e-01, -2.3742e-01,  1.6465e-01,  ..., -1.7200e-01,
          -8.2114e-02, -2.1017e-04],
         [ 1.9834e-01,  1.6758e-02,  5.6952e-02,  ...,  2.2517e-01,
          -6.4743e-03, -1.5757e-01]], grad_fn=<MmBackward0>),
 tensor([26, 18, 63,  ..., 33, 37, 58]),
 tensor(59.9182),
 tensor(0.0158))

In [4]:
32*10*256

81920

In [5]:
class VecQVAE(nn.Module):
    def __init__(self, inChannels = 1, hiddenDim = 32, codeBookdim = 128, embedDim = 128):
        super().__init__()
        self.inChannels = inChannels
        self.hiddenDim = hiddenDim
        self.codeBookdim = codeBookdim
        self.embedDim = embedDim

        self.encoder = nn.Sequential(
            nn.Conv2d(inChannels, hiddenDim, 4, 2, 1), 
            nn.BatchNorm2d(hiddenDim),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(hiddenDim, hiddenDim, 3, 1, 1),
            nn.BatchNorm2d(hiddenDim),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(hiddenDim, 2 * hiddenDim, 4, 2, 1),
            nn.BatchNorm2d(2 * hiddenDim),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(2 * hiddenDim, 2 * hiddenDim, 3, 1, 1),
            nn.BatchNorm2d(2 * hiddenDim),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(2 * hiddenDim, embedDim, 1),
        )

        self.vector_quantize = VectorQuantizeImage(codeBookDim=codeBookdim,embeddingDim=embedDim)

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(embedDim, 2 * hiddenDim, 4, 2, 1),
            nn.BatchNorm2d(2 * hiddenDim),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(2 * hiddenDim, 2 * hiddenDim, 3, 1, 1),
            nn.BatchNorm2d(2 * hiddenDim),
            nn.ReLU(inplace=True),
            
            nn.ConvTranspose2d(2 * hiddenDim, hiddenDim, 4, 2, 1),
            nn.BatchNorm2d(hiddenDim),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(hiddenDim, hiddenDim, 3, 1, 1),
            nn.BatchNorm2d(hiddenDim),
            nn.ReLU(inplace=True),
            
            nn.Conv2d(hiddenDim, inChannels, 1),
            nn.Sigmoid()
        )

    def encodeImage(self, x, noise_std = 0.15):
        if self.training:
            encodedOut = self.encoder(x)
            encodedOut = encodedOut + torch.randn_like(encodedOut) * noise_std
        else:
            encodedOut = self.encoder(x)

        return encodedOut

    def decodeImage(self, quantized_vector):
        decodedOut = self.decoder(quantized_vector)
        return decodedOut

    def forward(self, x):
        batch_size, time_frame, inChannels, height, width = x.shape

        x_frames = rearrange(x, 'b t c h w -> (b t) c h w')
        encodedOut = self.encodeImage(x_frames)
        batch_size_time_frame, encoded_channel, encoded_height, encoded_width = encodedOut.shape
        
        # print(f"Encoded Shape: {encodedOut.shape}")

        
        vectorize_input = rearrange(encodedOut, 'bt d h w -> (bt h w) d')
        quantized_vectors, encoding_indices, perplexity, diversity_loss  = self.vector_quantize(vectorize_input)
        codebook_loss = Fn.mse_loss(vectorize_input.detach(), quantized_vectors)
        commitment_loss = Fn.mse_loss(vectorize_input, quantized_vectors.detach())

        quantized_vectors = vectorize_input + (quantized_vectors - vectorize_input).detach()
        # print(f"CodeBook Loss: {codebook_loss} , Commitment Loss: {commitment_loss}")
        # print(f"Quantized SHape: {quantized_vectors.shape}")

        decoder_input = rearrange(quantized_vectors, '(bt h w) d -> bt d h w', bt = batch_size_time_frame, d = encoded_channel, h = encoded_height, w = encoded_width)
        # print(f"Decoded Input SHape: {decoder_input.shape}")
        decodedOut = self.decodeImage(decoder_input)

        # print(f"Decoded SHape: {decodedOut.shape}")
        
        return decoder_input, decodedOut, codebook_loss, commitment_loss, encoding_indices, perplexity, diversity_loss

VQ = VecQVAE(inChannels = 3, hiddenDim = 256, codeBookdim = 128, embedDim = 64)
test = torch.randn(32, 10, 3, 64, 64)
quantized_latents, decoderOut, codebook_loss, commitment_loss, encoding_indices, perplexity, diversity_loss = VQ(test)
quantized_latents.shape, decoderOut.shape, codebook_loss, commitment_loss, encoding_indices.shape, perplexity, diversity_loss

(torch.Size([320, 64, 16, 16]),
 torch.Size([320, 3, 64, 64]),
 tensor(0.1738, grad_fn=<MseLossBackward0>),
 tensor(0.1738, grad_fn=<MseLossBackward0>),
 torch.Size([81920]),
 tensor(46.8948),
 tensor(0.2069))

In [6]:
class SelfAttentionBlock(nn.Module):
    def __init__(self, inChannels, heads = 8):
        super().__init__()
        self.query = nn.Conv2d(inChannels, inChannels // heads, 1)
        self.key = nn.Conv2d(inChannels, inChannels // heads, 1)
        self.value = nn.Conv2d(inChannels, inChannels, 1)

        self.coeff = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch, channel, height, width = x.shape
        q = self.query(x).view(batch, -1, height * width).permute(0, 2, 1)
        k = self.key(x).view(batch, -1, height * width)
        v = self.value(x).view(batch, -1, height * width)
        attn = torch.matmul(q, k)                                          
        attn = Fn.softmax(attn, dim=-1)
        attn_reshaped = attn.permute(0, 2, 1) 
        
        out = torch.matmul(v, attn_reshaped)                        
        out = out.view(batch, channel, height, width)                    
        out = self.coeff * out + x                                      
        return out

rad = torch.randn(10, 128, 64, 64)
sAtt = SelfAttentionBlock(128, 4)
out = sAtt(rad)
out.shape

torch.Size([10, 128, 64, 64])

In [7]:
class VecQVAE(nn.Module):
    def __init__(self, inChannels=3, hiddenDim=256, codeBookdim=256, embedDim=128):
        super().__init__()
        self.inChannels = inChannels
        self.hiddenDim = hiddenDim
        self.codeBookdim = codeBookdim
        self.embedDim = embedDim

        self.block1 = nn.Sequential(
            nn.Conv2d(inChannels, hiddenDim, 4, 2, 1),
            nn.BatchNorm2d(hiddenDim),
            nn.ReLU(inplace=True)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(hiddenDim, hiddenDim, 3, 1, 1),
            nn.BatchNorm2d(hiddenDim),
            nn.ReLU(inplace=True)
        )
        self.block3 = nn.Sequential(
            nn.Conv2d(hiddenDim, 2 * hiddenDim, 4, 2, 1),
            nn.BatchNorm2d(2 * hiddenDim),
            nn.ReLU(inplace=True),
            SelfAttentionBlock(2 * hiddenDim)
        )

        self.block4 = nn.Sequential(
            nn.Conv2d(2 * hiddenDim, 2 * hiddenDim, 3, 1, 1),
            nn.BatchNorm2d(2 * hiddenDim),
            nn.ReLU(inplace=True),
            SelfAttentionBlock(2 * hiddenDim)
        )

        self.block5 = nn.Sequential(
            nn.Conv2d(2 * hiddenDim, embedDim, 1)
        )

        self.vector_quantize = VectorQuantizeImage(codeBookDim=codeBookdim, embeddingDim=embedDim)

        self.block6 = nn.Sequential(
            nn.ConvTranspose2d(embedDim, 2 * hiddenDim, 1),
            nn.BatchNorm2d(2 * hiddenDim),
            nn.ReLU(inplace=True)
        )
        self.block7 = nn.Sequential(
            nn.Conv2d(2 * hiddenDim, 2 * hiddenDim, 3, 1, 1),
            nn.BatchNorm2d(2 * hiddenDim),
            nn.ReLU(inplace=True)
        )
        self.block8 = nn.Sequential(
            nn.ConvTranspose2d(2 * hiddenDim, hiddenDim, 4, 2, 1),
            nn.BatchNorm2d(hiddenDim),
            nn.ReLU(inplace=True)
        )
        self.block9 = nn.Sequential(
            nn.Conv2d(hiddenDim, hiddenDim, 3, 1, 1),
            nn.BatchNorm2d(hiddenDim),
            nn.ReLU(inplace=True)
        )
        self.block10 = nn.Sequential(
            nn.ConvTranspose2d(hiddenDim, hiddenDim // 2, 4, 2, 1),
            nn.BatchNorm2d(hiddenDim // 2),
            nn.ReLU(inplace=True)
        )
       
        self.outputlayer = nn.Sequential(
            nn.Conv2d(hiddenDim // 2, inChannels, 1),
            nn.Sigmoid()
        )

    def encodeImage(self, x, noise_std=0.15):
        if self.training:
            x1 = self.block1(x)
            x2 = self.block2(x1)
            x3 = self.block3(x2)
            x4 = self.block4(x3)
            encoded = self.block5(x4)
            encoded += torch.randn_like(encoded) * noise_std
        else:
            x1 = self.block1(x)
            x2 = self.block2(x1)
            x3 = self.block3(x2)
            x4 = self.block4(x3)
            encoded = self.block5(x4)
        return encoded, (x2, x3, x4)

    def decodeImage(self, quantized_vector, skips):
        x2, x3, x4 = skips
        # print(x2.shape, x3.shape, x4.shape)
        x = self.block6(quantized_vector)
        x = self.block7(x + x4)
        x = self.block8(x + x3)
        x = self.block9(x + x2)
        x = self.block10(x)
        return self.outputlayer(x)

    def forward(self, x):
        batch, timeFrames, channel, height, width = x.shape
        x = rearrange(x, 'b t c h w -> (b t) c h w')

        encoded, skips = self.encodeImage(x)
        batchTime, encodedChannels, encodedHeight, encodedWidth = encoded.shape
        encoder_reshaped = rearrange(encoded, 'bt d h w -> (bt h w) d')
        quantized_vectors, encoding_indices, perplexity, diversity_loss = self.vector_quantize(encoder_reshaped)

        codebook_loss = Fn.mse_loss(encoder_reshaped.detach(), quantized_vectors)
        commitment_loss = Fn.mse_loss(encoder_reshaped, quantized_vectors.detach())

        quantized_vectors = encoder_reshaped + (quantized_vectors - encoder_reshaped).detach()
        # print(quantized.shape)
        decoder_input = rearrange(quantized_vectors, '(bt h w) d -> bt d h w', bt=batch * timeFrames, d=encodedChannels, h=encodedHeight, w=encodedWidth)
        # print(decoder_input.shape)

        decodedOut= self.decodeImage(decoder_input, skips)
        return decoder_input, decodedOut, codebook_loss, commitment_loss, encoding_indices, perplexity, diversity_loss

VQ = VecQVAE(inChannels = 3, hiddenDim = 512, codeBookdim = 1024, embedDim = 128)
test = torch.randn(1, 10, 3, 64, 64)
quantized_latents, decoderOut, codebook_loss, commitment_loss, encoding_indices, perplexity, diversity_loss = VQ(test)
quantized_latents.shape, decoderOut.shape, codebook_loss, commitment_loss, encoding_indices.shape, perplexity, diversity_loss

(torch.Size([10, 128, 16, 16]),
 torch.Size([10, 3, 64, 64]),
 tensor(0.1665, grad_fn=<MseLossBackward0>),
 tensor(0.1665, grad_fn=<MseLossBackward0>),
 torch.Size([2560]),
 tensor(271.1230),
 tensor(0.1917))

In [30]:
dataset = pd.read_csv("/Users/ishananand/Desktop/Text-To-Video-Generation/data/modified_tgif.csv")
dataset = dataset[(dataset['frames'] <= 40) & (dataset['frames'] > 15)].copy().reset_index(drop=True)
# dataset = dataset[:10000] # thread 1 first 10000

dataset.shape

(78568, 3)

In [9]:
def getNumpyArray(dataset, index):
    url = dataset['url'][index]
    resp = urllib.request.urlopen(url)
    image = np.asarray(bytearray(resp.read()), dtype="uint8")
    image = cv2.imdecode(image, cv2.IMREAD_COLOR_RGB)

    return image

tImg = getNumpyArray(dataset, 10)
tImg.shape

(377, 500, 3)

In [10]:
def getNumpyArray(dataset, index):
    url = dataset.iloc[index]['url']
    resp = urllib.request.urlopen(url)
    data = resp.read()
    pil_image = Image.open(io.BytesIO(data))
    
    frames = []
    for frame in ImageSequence.Iterator(pil_image):
        frame = frame.convert("RGB")
        frame_np = np.array(frame)
        frames.append(frame_np)
    
    frames = np.array(frames)
    return frames


tImg = getNumpyArray(dataset, 2000)
tImg.shape

(23, 281, 500, 3)

In [11]:
print(dataset['caption'][2222])
HTML(f'<img src="{dataset['url'][2222]}" />')

a man is shaking a dog's hand. w


In [12]:
class FrameDataset(Dataset):
    def __init__(self, data, totalSequence = 40, transform = None):
        super().__init__()
        self.data = data
        self.transform = transform
        self.totalSequence = totalSequence

    def __len__(self):
        return len(self.data)
    
    def npArray(self, index):
        try:
            row = self.data.iloc[index]
            totalframes = self.data.iloc[index]['frames']
            url = row['url']
            resp = urllib.request.urlopen(url)
            image_data = resp.read()
            img = Image.open(io.BytesIO(image_data))
    
            frames = []
            for frame in ImageSequence.Iterator(img):
                frame_rgb = frame.convert("RGB")
                frames.append(np.array(frame_rgb))
    
            return frames
    
        except Exception as e:
            print(f"Error processing index {index}: {e}")
            fallback = torch.zeros((256, 256, 3), dtype=torch.uint8)
            return [fallback.numpy()]
    
    def __getitem__(self, index):
        # print(index)
        gif = self.npArray(index)
        caption = self.data.iloc[index]['caption']
        totalframes = len(gif)#self.data.iloc[index]['frames']
        
        if totalframes < self.totalSequence:
            gif += [gif[-1]] * (self.totalSequence - totalframes)

        tensorFrames = torch.stack([
            self.transform(Image.fromarray(frame)) for frame in gif
        ])

        tensorFrames = tensorFrames/255.0

        return tensorFrames, caption
    

tranform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

fdata = FrameDataset(dataset, transform=tranform)

X, Y = fdata.__getitem__(1)
print(X.shape, Y)

torch.Size([40, 3, 256, 256]) a man dressed in red is dancing.


In [85]:
class FrameDataset(Dataset):
    def __init__(self, data, totalSequence=40, transform=None, cache_dir='/Users/ishananand/Desktop/Text-To-Video-Generation/data/cacheGIF'):
        super().__init__()
        self.data = data
        self.transform = transform
        self.totalSequence = totalSequence
        self.cache_dir = cache_dir
        os.makedirs(self.cache_dir, exist_ok=True)

    def __len__(self):
        return len(self.data)

    def cache_path(self, index):
        url = self.data.iloc[index]['url']
        hash_name = hashlib.md5(url.encode()).hexdigest()
        return os.path.join(self.cache_dir, f'{hash_name}.pt')

    def npArray(self, index):
        row = self.data.iloc[index]
        url = row['url']
        resp = urllib.request.urlopen(url)
        image_data = resp.read()
        img = Image.open(io.BytesIO(image_data))

        frames = []
        for frame in ImageSequence.Iterator(img):
            frame_rgb = frame.convert("RGB")
            frames.append(np.array(frame_rgb))
        return frames

    def __getitem__(self, index):
        totalframes = self.data.iloc[index]['frames']
        caption = self.data.iloc[index]['caption']

        cache_path = self.cache_path(index)

        if os.path.exists(cache_path):
            tensorFrames = torch.load(cache_path)
            return tensorFrames, caption

        gif = self.npArray(index)
        caption = self.data.iloc[index]['caption']
        totalframes = self.data.iloc[index]['frames']

        if totalframes < self.totalSequence:
            gif += [gif[-1]] * (self.totalSequence - totalframes)
        

        tensorFrames = torch.stack([
            self.transform(Image.fromarray(frame)) for frame in gif
        ])

        tensorFrames = tensorFrames / 255.0
        torch.save(tensorFrames, cache_path)

        return tensorFrames, caption
    
tranform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

fdata = FrameDataset(dataset, transform=tranform)

X, Y = fdata.__getitem__(1)
print(X.shape, Y)

torch.Size([40, 3, 256, 256]) a man dressed in red is dancing.


In [None]:
class FrameDataset(Dataset):
    def __init__(self, data, totalSequence=40, transform=None, cache_dir='/Users/ishananand/Desktop/Text-To-Video-Generation/data/cacheGIF'):
        super().__init__()
        self.data = data
        self.transform = transform
        self.totalSequence = totalSequence
        self.cache_dir = cache_dir
        os.makedirs(self.cache_dir, exist_ok=True)

    def __len__(self):
        return len(self.data)

    def npArray(self, index):
        row = self.data.iloc[index]
        url = row['url']
        resp = urllib.request.urlopen(url)
        image_data = resp.read()
        img = Image.open(io.BytesIO(image_data))

        frames = []
        for frame in ImageSequence.Iterator(img):
            frame_rgb = frame.convert("RGB")
            frames.append(np.array(frame_rgb))
        return frames

    def __getitem__(self, index):
        row = self.data.iloc[index]
        url = row['url']
        caption = row['caption']
        totalframes = row['frames']
        gif_path = os.path.join(self.cache_dir, f'{index}.gif')

        if not os.path.exists(gif_path):
            resp = urllib.request.urlopen(url)
            image_data = resp.read()
            with open(gif_path, 'wb') as f:
                f.write(image_data)
        else:
            with open(gif_path, 'rb') as f:
                image_data = f.read()

        img = Image.open(io.BytesIO(image_data))

        frames = []
        for frame in ImageSequence.Iterator(img):
            frame_rgb = frame.convert("RGB")
            frames.append(np.array(frame_rgb))

        if len(frames) < self.totalSequence:
            frames += [frames[-1]] * (self.totalSequence - len(frames))
        else:
            frames = frames[:self.totalSequence]

        if self.transform:
            tensorFrames = torch.stack([
                self.transform(Image.fromarray(frame)) for frame in frames
            ])
            tensorFrames = tensorFrames / 255.0
            return tensorFrames, caption
        else:
            return frames, caption

    
tranform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

fdata = FrameDataset(dataset, transform=tranform)

X, Y = fdata.__getitem__(10)
print(X.shape, Y)

torch.Size([40, 3, 256, 256]) a woman is laughing and holding a man, the man is not laughing


In [63]:
dataset.shape[0]

78568

In [None]:
threads = []
i = 0
while i < dataset.shape[0]:
    threads.append(dataset[i:i+10000])
    i = i + 10000

threads[0].shape

(10000, 3)

In [61]:
cummulativeIndices = []

for i in range(len(threads)):
    indices = threads[i].index[threads[i]['frames'] > 35].tolist()
    cummulativeIndices.extend(indices)

print(len(cummulativeIndices))
cummulativeData = dataset.loc[cummulativeIndices]
cummulativeData.shape

13104


(13104, 3)

In [84]:
cummulativeData.head()

Unnamed: 0,url,caption,frames
11,https://38.media.tumblr.com/88bd6013c943b41c93...,a man with lights on his jacket watching a lar...,36
15,https://38.media.tumblr.com/88976e0d8b0e068ebd...,the clouds are moving next to the full moon,38
34,https://38.media.tumblr.com/70f4c1fc63e69bafb2...,a girl in a football uniform is dancing and si...,38
40,https://38.media.tumblr.com/6fb91b85be51fd128b...,a man with longish hair laughs then puts his h...,36
56,https://38.media.tumblr.com/819a74539497ae1899...,a man are doing acrobatic stunts on the wing o...,37


In [79]:
cummulativeData.loc(13123)

ValueError: No axis named 13123 for object type DataFrame

In [None]:
cummulativeData(13123)

ValueError: No axis named 13123 for object type DataFrame

In [73]:
cummulativeIndices[:5]

[11, 15, 34, 40, 56]

In [None]:
for i in cummulativeIndices:
    if i == 13123 :
        print()

13123


In [69]:
def plotFrameCount(dataset):
    frameFreq = {}

    for i in dataset['frames']:
        if i in frameFreq:
            frameFreq[i] += 1
        else:
            frameFreq[i] = 1

    plt.figure(figsize=(5, 2))
    plt.bar(frameFreq.keys(), frameFreq.values())
    plt.xlabel('Frame Count')
    plt.ylabel('Frequency')
    plt.title('Frequency Frame')
    plt.show()


In [15]:
BATCH_SIZE = 2
codeBookdim = 1024
embedDim = 256
hiddenDim = 512
inChannels = 3
tranform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])
torchDataset = FrameDataset(dataset, transform=tranform)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataloader = DataLoader(torchDataset, batch_size=BATCH_SIZE, shuffle = True)
modelA = VecQVAE(inChannels = inChannels, hiddenDim = hiddenDim, codeBookdim = codeBookdim, embedDim = embedDim).to(device)
lossFn = nn.MSELoss()
optimizerA = torch.optim.Adam(
                [
                    {'params': modelA.parameters(), 'lr': 2e-4},
                    # {'params': modelA.decodeImage.parameters(), 'lr': 2e-4},
                    # {'params': modelA.vector_quantize.parameters(), 'lr': 1e-4}
                ], weight_decay=1e-5)
schedulerA = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizerA, T_0=10, T_mult=2, eta_min=1e-6
            )

epochs = 1000

In [16]:
def perceptualLoss(pred, target):
    vgg = vgg16(pretrained = True).features[:17].eval()
    vgg.to(device)
    for param in vgg.parameters():
        param.requires_grad = False

    # print(pred.shape)
    batch, channels, height, width = pred.shape

    pred = pred.view(batch, channels, height, width)
    target = target.view(batch, channels, height, width)

    if pred.shape[1] == 1:
        pred = pred.repeat(1, 3, 1, 1)
        target = target.repeat(1, 3, 1, 1)


    vgg_pred = vgg(pred).to(device)
    vgg_true = vgg(target).to(device)

    perceptualoss = Fn.mse_loss(vgg_pred, vgg_true)
    return perceptualoss

pred = torch.randn(1, 10, 3, 10, 10)
pred1 = torch.randn(1, 10, 3, 10, 10)
out = perceptualLoss(pred, pred1)
out.item()



ValueError: too many values to unpack (expected 4)

In [None]:
def lab_color_loss(pred, target):
    pred_lab = kornia.color.rgb_to_lab(pred)
    target_lab = kornia.color.rgb_to_lab(target)
    loss = Fn.mse_loss(pred_lab, target_lab)
    return loss

pred = torch.randn(1, 3, 10, 10)
pred1 = torch.randn(1, 3, 10, 10)
out = lab_color_loss(pred, pred1)
out.item()

19654.966796875

In [None]:
# modelValA = torch.load("./projects/t2v-gif/models/VQVAE-GIF.pt", map_location=torch.device('cpu'))
# modelA.load_state_dict(modelValA)

for each_epoch in range(epochs):
    modelA.train()
    reconstruct_loss = 0.0
    codeb_loss = 0.0
    commit_loss = 0.0
    vqvaeloss = 0.0
    diverse_loss = 0.0
    ssim_loss = 0.0
    
    loop = tqdm(dataloader, f"{each_epoch}/{epochs}")
    perplexities = []

    for X, caption in loop:
        X = X.to(device)
        # Y = Y.to(device)
        
        quantized_latents, decoderOut, codebook_loss, commitment_loss, encoding_indices, perplexity, diversity_loss = modelA(X)
        
        # print(X.shape, decoderOut.shape)
        X = rearrange(X, 'b t d h w -> (b t) d h w', b = BATCH_SIZE, t = 40, d = 3, h = 128, w = 128)
        
        ssim_score = ssim(X, decoderOut, data_range=1.0)
        ssim_loss = 1.0 - ssim_score

        # reconstruction_loss = torch.mean((X - decoderOut)**2)
        reconstruction_loss = Fn.l1_loss(decoderOut, X)
        colorLoss = lab_color_loss(decoderOut, X)
        perceptualoss = perceptualLoss(decoderOut, X)
        
        loss = reconstruction_loss + codebook_loss + 0.2 * commitment_loss + 0.1 * diversity_loss + 0.1 * ssim_loss + 0.1 * perceptualoss + 0.1 * colorLoss
        vqvaeloss += loss.item()

        
        reconstruct_loss += reconstruction_loss.item()
        diverse_loss += diversity_loss.item()
        codeb_loss += codebook_loss.item()
        commit_loss += commitment_loss.item()
        ssim_loss += ssim_loss.item()
        perplexities.append(perplexity)
        
        
        optimizerA.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(modelA.parameters(), max_norm=1.0)
        optimizerA.step()
        loop.set_postfix({
            "TotalL": f"{vqvaeloss}", 
            "ReconsL": f"{reconstruct_loss}", 
            "CodeL":f"{codeb_loss}",
            "CommitL":f"{commitment_loss}", 
            "Perplexity":f"{perplexity}", 
            "Diversity Loss":f"{diverse_loss}", 
            "SSIM Loss":f"{ssim_loss}",
            "Perceptual Loss":f"{perceptualoss}",
            "Color Loss":f"{colorLoss}"
        })
    #     break
    # break

    average_perplexity = sum(perplexities)/len(perplexities)
    vqvaeloss /= len(dataloader)   
    reconstruct_loss /= len(dataloader)   
    codeb_loss /= len(dataloader)   
    commit_loss /= len(dataloader)   
    diverse_loss /= len(dataloader)
    perceptualoss /= len(dataloader)
    colorLoss /= len(dataloader)
    torch.save(modelA.state_dict(), "./models/VQVAE-GIF.pt")
    wandb.log({
        "VQVAE LR": optimizerA.param_groups[0]['lr'],
        "VQVAE Loss": vqvaeloss,
        "Reconstruction Loss": reconstruct_loss,
        "Codebook Loss": codeb_loss,
        "Commitment Loss": commit_loss,
        "Diversity Loss": diverse_loss,
        "Perplexity": average_perplexity,
        "SSIM Loss":ssim_loss,
        "Perceptual Loss":perceptualoss,
        "Color Loss":colorLoss
    })
    schedulerA.step()


0/1000:   0%|          | 0/5000 [00:00<?, ?it/s]

: 