In [1]:
# !unzip /content/random-frames-ucf-101.zip

In [2]:
# import torch.nn as nn

# class VGGReconstructor(nn.Module):

#   def __init__(self):
#     super(VGGReconstructor, self).__init__()

#     self.block1 = nn.Sequential(
#         nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, stride=1, padding=1),
#         nn.ReLU(),
#         nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=2, padding=1),
#         nn.BatchNorm2d(64),
#         nn.ReLU()
#     )

#     self.block2 = nn.Sequential(
#         nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
#         nn.BatchNorm2d(128),
#         nn.ReLU(),
#         nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1),
#         nn.BatchNorm2d(128),
#         nn.ReLU()        
#     )

#     self.block3 = nn.Sequential(
#         nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
#         nn.BatchNorm2d(256),
#         nn.ReLU(),
#         nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
#         nn.BatchNorm2d(256),
#         nn.ReLU(),
#         nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1),
#         nn.BatchNorm2d(256),
#         nn.ReLU()
#     )

#     self.block4 = nn.Sequential(
#         nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
#         nn.BatchNorm2d(512),
#         nn.ReLU(),
#         nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
#         nn.BatchNorm2d(512),
#         nn.ReLU(),
#         nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=2, padding=1),
#         nn.BatchNorm2d(512),
#         nn.ReLU()        
#     )

#     self.block5 = nn.Sequential(
#         nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
#         nn.BatchNorm2d(512),
#         nn.ReLU(),
#         nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
#         nn.BatchNorm2d(512),
#         nn.ReLU(),
#         nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
#         nn.BatchNorm2d(512),
#         nn.ReLU(),
#         nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=2, padding=1),
#         nn.BatchNorm2d(512),
#         nn.ReLU()
#     )

#   def forward(self, x_in):
#         out1 = self.block1(x_in)
#         out2 = self.block2(out1)
#         out3 = self.block3(out2)
#         out4 = self.block4(out3)
#         out5 = self.block5(out4)

#         return out5

In [3]:
import numpy as np # type: ignore
import os
import pandas as pd # type: ignore
from tqdm import tqdm # type: ignore
import natsort
import torch.nn as nn
import glob
import torch
from torch.nn import DataParallel as DDP
import torch.nn as nn # type: ignore
import torch.nn.functional as F # type: ignore
import torch.optim as optim # type: ignore
from torch.utils.data import Dataset, DataLoader # type: ignore
import matplotlib.pyplot as plt # type: ignore
import torchvision
import torchvision.models as models # type: ignore
from torchvision.models import VGG16_Weights# type: ignore
from torch.amp import autocast, GradScaler

import torchvision.transforms as transforms # type: ignore
from PIL import Image # type: ignore


class VGG16FeatureExtractor:
    """Frozen feature extractor around a pre-trained ``VGGReconstructor``.

    The network is wrapped in ``DataParallel`` (imported here as ``DDP``), its
    weights are loaded from a fixed Kaggle checkpoint path, and it is put in eval
    mode. ``extract_features`` runs a gradient-free forward pass.

    NOTE(review): ``VGGReconstructor`` is only defined in a commented-out cell of
    this notebook; that cell must be re-enabled before this class can be built.
    """

    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        self.model = DDP(VGGReconstructor()).to(device)
        # map_location lets a checkpoint that was saved on a GPU be opened on the
        # CPU fallback device as well.
        state_dict = torch.load(
            '/kaggle/input/pre-trained-vgg-for-colouring-gan/pytorch/default/1/best_model.pth',
            map_location=device,
            weights_only=True,
        )
        # strict=False skips mismatched keys without raising. Because the model is
        # wrapped in DataParallel its keys carry a "module." prefix, so a checkpoint
        # saved from the bare model would load nothing — surface that instead of
        # silently running with random weights. TODO confirm checkpoint key format.
        incompatible = self.model.load_state_dict(state_dict, strict=False)
        if incompatible.missing_keys:
            print(f"Warning: {len(incompatible.missing_keys)} parameter(s) were not found "
                  f"in the checkpoint and keep their initial values")
        self.model.eval()

    def extract_features(self, image):
        """Run ``image`` through the frozen network without tracking gradients.

        Args:
            image: Input tensor; it is moved to ``self.device`` first.

        Returns:
            The raw output of the wrapped model.
        """
        image = image.to(self.device)
        with torch.no_grad():
            features = self.model(image)
        return features

class CustomDataSet(Dataset):
    """Flat image-folder dataset yielding ``(transformed_image, file_name)`` pairs.

    Every entry in ``root`` is treated as an image file. Entries are ordered with
    ``natsort.natsorted`` so the index order is stable across runs.

    Args:
        root: Directory containing the images.
        transform: Callable applied to each RGB ``PIL`` image.
    """

    def __init__(self, root, transform):
        self.main_dir = root
        self.transform = transform
        self.total_imgs = natsort.natsorted(os.listdir(root))

    def __len__(self):
        return len(self.total_imgs)

    def __getitem__(self, idx):
        file_name = self.total_imgs[idx]
        # Force 3 channels so grayscale / RGBA files share one code path.
        rgb = Image.open(os.path.join(self.main_dir, file_name)).convert("RGB")
        return self.transform(rgb), file_name
        
def read_keyframe_file(csv_path):
    """Read a whitespace-separated keyframe file into a DataFrame.

    For each non-blank line, the second field (index 1) is parsed as the integer
    keyframe number and the fifth field (index 4) as the float timestamp. Blank or
    whitespace-only lines (e.g. a trailing newline) are skipped.

    Args:
        csv_path: Path to the keyframe file. Despite the name, lines are split on
            whitespace, not commas.

    Returns:
        A DataFrame with columns ``keyframe`` and ``timestamp`` (the columns are
        present even when the file has no data lines), or ``None`` if the file
        cannot be opened or a line cannot be parsed; the error is printed.
    """
    try:
        records = []
        with open(csv_path, 'r') as f:
            for line in f:
                parts = line.split()
                if not parts:
                    # Blank lines carry no data and would otherwise raise IndexError.
                    continue
                records.append({'keyframe': int(parts[1]), 'timestamp': float(parts[4])})
        return pd.DataFrame(records, columns=['keyframe', 'timestamp'])
    except Exception as e:
        print(f"Error reading file {csv_path}: {str(e)}")
        return None

# def process_data(base_dir):
#     all_features={}
#     feature_extractor = VGG16FeatureExtractor()
#     for img in tqdm(os.listdir(base_dir)):
#       try:
#         frame_path = os.path.join(base_dir, img)
#         features  = feature_extractor.extract_features(frame_path)
#         all_features[img] = {
#             'features': features
#         }
#       except Exception as e:
#         print(f"Error processing {img}: {str(e)}")
#         continue
#     return all_features

# Model architecture
class DoubleConv(nn.Module):
    """Two stacked ``Conv3x3 -> BatchNorm -> ReLU`` stages.

    Padding of 1 keeps the spatial size unchanged; only the channel count changes
    (``in_channels`` -> ``out_channels``).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        for stage_in in (in_channels, out_channels):
            layers.extend([
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ])
        # Attribute name and layer order are what the state_dict keys are built from.
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)

class ColorizationUNetWithAttentionDecoder(nn.Module):
    """U-Net style colorization generator with an auxiliary attention (saliency) decoder.

    Every convolution uses kernel 3, stride 1 and padding 1, so all feature maps
    keep the spatial size of the input. The deepest encoder map (``enc3``, 512
    channels) is concatenated with pre-computed per-image features looked up by
    name in ``loaded_dataset.data``. Two outputs are produced:

    * ``color_output`` — 3 channels, from ``final_color`` applied to ``dec1``;
    * ``saliency_out`` — 1 channel, from an attention decoder that fuses
      ``dec1``, ``dec2`` and ``dec3``.

    Args:
        loaded_dataset: Object whose ``data`` attribute maps an image name to a
            NumPy feature array (such as ``VideoFeaturesDataset``). The features
            must have 512 channels (3-D arrays) so that, with the 512-channel
            ``enc3``, they match the 1024 input channels of ``middle``.
        input_channels: Number of channels of the input image (1 for grayscale).
    """

    def __init__(self, loaded_dataset, input_channels=1):
        super().__init__()

        # Encoder: 3x3 convolutions, stride 1 (spatial size preserved).
        self.enc1 = nn.Sequential(
            nn.Conv2d(input_channels, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True)
        )

        self.enc2 = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )

        self.enc3 = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True)
        )

        # Bottleneck: enc3 (512 ch) concatenated with the looked-up features (512 ch).
        self.middle = nn.Sequential(
            nn.Conv2d(1024, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        # Decoder; input channels = upstream output + skip connection.
        # dec3: middle (64) + enc3_new (1024) = 1088
        self.dec3 = nn.Sequential(
            nn.Conv2d(1088, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )

        # dec2: enc2 (256) + dec3 (128)
        self.dec2 = nn.Sequential(
            nn.Conv2d(256 + 128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )

        # dec1: enc1 (512) + dec2 (256)
        self.dec1 = nn.Sequential(
            nn.Conv2d(512 + 256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True)
        )

        # Attention decoder: one 128-channel branch per decoder level.
        self.layer1a = nn.Sequential(
            nn.Conv2d(in_channels = 512, out_channels = 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU()
        )

        self.layer1b = nn.Sequential(
            nn.Conv2d(in_channels = 128, out_channels = 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU()
        )

        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels = 256, out_channels = 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU()
        )

        self.layer3 = nn.Sequential(
            nn.Conv2d(in_channels = 128, out_channels = 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU()
        )

        # Fuses the three 128-channel branches (3 * 128 = 384).
        self.layer4 = nn.Sequential(
            nn.Conv2d(in_channels = 384, out_channels = 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU()
        )

        self.final_layer = nn.Sequential(
            nn.Conv2d(in_channels = 128, out_channels = 1, kernel_size=3, stride=1, padding=1)
        )

        # Output heads
        self.final_color = nn.Conv2d(512, 3, kernel_size=3, stride=1, padding=1)
        # saliency_dec* are not used by forward(); kept so existing checkpoints still load.
        self.saliency_dec3 = nn.Conv2d(128, 1, kernel_size=3, stride=1, padding=1)
        self.saliency_dec2 = nn.Conv2d(256, 1, kernel_size=3, stride=1, padding=1)
        self.saliency_dec1 = nn.Conv2d(512, 1, kernel_size=3, stride=1, padding=1)

        self.loaded_dataset = loaded_dataset

    def _lookup_vgg_features(self, img_name, batch_size, device):
        """Stack the pre-computed features for ``img_name`` into a ``batch_size`` batch.

        2-D arrays get a leading channel axis. If fewer names than ``batch_size``
        are given, the last feature is repeated; if more, the surplus is dropped.

        NOTE(review): under ``DataParallel`` the non-tensor name tuple appears to be
        replicated whole to every replica, so truncation keeps the *first* names
        for every replica — confirm before training on more than one GPU.
        """
        features_list = []
        for name in img_name:
            feature_tensor = torch.from_numpy(self.loaded_dataset.data[name]).float()
            if feature_tensor.dim() == 2:
                feature_tensor = feature_tensor.unsqueeze(0)
            features_list.append(feature_tensor)
        # Move to the model's device before resizing so interpolation runs there.
        from_vgg = torch.stack(features_list).to(device)

        n_found = from_vgg.size(0)
        if n_found < batch_size:
            repeats = from_vgg[-1:].repeat(batch_size - n_found, 1, 1, 1)
            from_vgg = torch.cat([from_vgg, repeats], dim=0)
        elif n_found > batch_size:
            from_vgg = from_vgg[:batch_size]
        return from_vgg

    def forward(self, x, img_name):
        """Run the generator.

        Args:
            x: Input batch with ``input_channels`` channels.
            img_name: Sequence of keys into ``loaded_dataset.data``, one per sample.

        Returns:
            ``(color_output, saliency_out)``: 3-channel and 1-channel maps with the
            same spatial size as ``x``.
        """
        # Encoder path
        enc1 = self.enc1(x)
        enc2 = self.enc2(enc1)
        enc3 = self.enc3(enc2)

        # Inject the pre-computed features, resized to the encoder's spatial size.
        from_vgg = self._lookup_vgg_features(img_name, enc3.size(0), enc3.device)
        from_vgg = torch.nn.functional.interpolate(
            from_vgg, size=enc3.shape[2:], mode="bilinear", align_corners=False)
        enc3_new = torch.cat([enc3, from_vgg], dim=1)

        # Bottleneck
        middle = self.middle(enc3_new)

        # Decoder path with skip connections
        dec3 = self.dec3(torch.cat([middle, enc3_new], dim=1))
        dec2 = self.dec2(torch.cat([dec3, enc2], dim=1))
        dec1 = self.dec1(torch.cat([dec2, enc1], dim=1))

        # Colorization head
        color_output = self.final_color(dec1)

        # Attention decoder: per-level branches, concatenated and fused.
        x1 = self.layer1b(self.layer1a(dec1))
        x2 = self.layer2(dec2).to(x1.device)
        x3 = self.layer3(dec3).to(x1.device)
        fused = self.layer4(torch.cat((x1, x2, x3), 1))

        saliency_out = self.final_layer(fused)
        return color_output, saliency_out

# Modify the Dataset class to ensure correct tensor dimensions
# class VideoFeaturesDataset(Dataset):
#     def __init__(self, npz_path):
#         self.npz_path = npz_path
#         self.data = np.load(self.npz_path, mmap_mode='r')  # Memory-map for efficient access
#         self.keys = list(self.data.files)  # Store keys for indexing

#     def __len__(self):
#         return len(self.keys)

#     # def __getitem__(self, idx):
#     #     key = self.keys[idx]
#     #     features = self.data[key]

#     #     # Ensure features have the correct shape (C, H, W)
#     #     if len(features.shape) == 2:
#     #         features = features.reshape(1, *features.shape)
#     #     elif len(features.shape) == 3:
#     #         features = features.reshape(-1, features.shape[-2], 1, 1)

#     def __getitem__(self, idx):
#         try:
#             key = self.keys[idx]
#             features = self.data[key]
#             # print(features.shape)
#             # Ensure features have the correct shape (C, H, W)
#             if len(features.shape) == 2:
#                 features = features.reshape(1, *features.shape)
#             # elif len(features.shape) == 3:
#             #     features = features.reshape(7, -1, features.shape[-1])

#             # self.features.extend([torch.from_numpy(feat) for feat in features])
#             # print(features.shape)
#             return features

#         except Exception as e:
#             print(f"Error in __getitem__ for idx {idx}: {e}")
#             return None  # Return None explicitly to trigger error

# Training utilities
class AverageMeter:
    """Tracks the most recent value and a sample-count-weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean over ``n`` samples and refresh ``avg``."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

def save_checkpoint(state, filename):
    """Write a training checkpoint to disk and log the destination.

    Args:
        state: Any object ``torch.save`` can serialize, e.g. a dict holding
            ``'state_dict'``, ``'optimizer_G'`` and ``'epoch'``.
        filename: Destination path.
    """
    torch.save(obj=state, f=filename)
    message = f"Checkpoint saved: {filename}"
    print(message)

def load_checkpoint(model, optimizer, filename):
    """Restore model and optimizer state from a checkpoint file.

    NOTE(review): ``load_checkpoint`` is defined again in a later cell with the
    same signature, and that later definition is the one in effect — keep only one.

    Args:
        model: Module to restore. If it is wrapped (e.g. in ``DataParallel``) the
            weights go into ``model.module``; otherwise into ``model`` itself.
        optimizer: Optimizer restored from the checkpoint's ``'optimizer_G'`` entry.
        filename: Path of a checkpoint dict with ``'state_dict'``,
            ``'optimizer_G'`` and ``'epoch'`` entries.

    Returns:
        The stored ``'epoch'`` value, or 0 when ``filename`` does not exist.
    """
    if not os.path.isfile(filename):
        return 0
    # Load onto the CPU so a GPU-saved checkpoint also opens on a CPU-only host;
    # load_state_dict then copies tensors onto each parameter's own device.
    checkpoint = torch.load(filename, map_location='cpu', weights_only=True)
    target = getattr(model, 'module', model)
    target.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_G'])
    return checkpoint['epoch']

# def train_epoch(model, train_loader, criterion, optimizer, device):
#     model.train()
#     losses = AverageMeter()

#     with tqdm(train_loader, desc="Training") as pbar:
#         for features in pbar:
#             features = features.to(device)

#             optimizer.zero_grad()
#             color_output, saliency_map = model(features)

#             color_target = features[:, :2, :, :]
#             color_loss = criterion(color_output, color_target)

#             # Calculate saliency losses
#             saliency_target = torch.mean(features, dim=1, keepdim=True)  # Create target from input features
#             saliency_loss = criterion(saliency_map, saliency_target)

#             # Total loss
#             loss = color_loss + saliency_loss

#             loss.backward()
#             optimizer.step()

#             losses.update(loss.item(), features.size(0))
#             pbar.set_postfix({'Loss': f'{losses.avg:.4f}'})

#     return losses.avg

# def validate(model, val_loader, criterion, device):
#     model.eval()
#     losses = AverageMeter()

#     with torch.no_grad():
#         for features in tqdm(val_loader, desc="Validating"):
#             features = features.to(device)

#             color_output, saliency_map = model(features)

#             # Calculate color loss
#             color_target = features[:, :2, :, :]
#             color_loss = criterion(color_output, color_target)

#             # Calculate saliency losses
#             saliency_target = torch.mean(features, dim=1, keepdim=True)
#             saliency_loss = criterion(saliency_map, saliency_target)

#             # Total loss
#             loss = color_loss + saliency_loss

#             losses.update(loss.item(), features.size(0))

#     return losses.avg


In [4]:
class Discriminator(nn.Module):
    """Convolutional discriminator returning one raw logit per image.

    Four 3x3 stride-2 convolutions (64 -> 128 -> 256 -> 512 channels), each
    followed by LeakyReLU(0.2) and Dropout(0.3); stages 2-4 also apply BatchNorm
    before the activation. Adaptive average pooling to 1x1 plus a 1x1 convolution
    gives an ``(N, 1, 1, 1)`` output. No sigmoid is applied.

    Args:
        input_channels: Channels of the images to classify.
    """

    def __init__(self, input_channels):
        super().__init__()

        # Layers are created in the same order as before so seeded initialization
        # and state_dict key order are unchanged.
        widths = [input_channels, 64, 128, 256, 512]
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, f"conv{stage}",
                    nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))

        # Pooling to 1x1 keeps the head valid even for tiny spatial inputs.
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Conv2d(512, 1, kernel_size=1)

        self.leaky_relu = nn.LeakyReLU(0.2, inplace=True)
        for stage, channels in ((2, 128), (3, 256), (4, 512)):
            setattr(self, f"batch_norm{stage}", nn.BatchNorm2d(channels))

        self.dropout = nn.Dropout(p=0.3)

    def forward(self, x):
        stages = (
            (self.conv1, None),
            (self.conv2, self.batch_norm2),
            (self.conv3, self.batch_norm3),
            (self.conv4, self.batch_norm4),
        )
        for conv, norm in stages:
            h = conv(x)
            if norm is not None:
                h = norm(h)
            x = self.dropout(self.leaky_relu(h))
        return self.fc(self.global_pool(x))

class GANLoss_autocast_compatible:
    """Adversarial ``BCEWithLogitsLoss`` against a constant-valued target.

    Calling the object as ``criterion(prediction, target_value)`` builds a tensor
    shaped like ``prediction`` filled with ``target_value`` (a float such as 0.9,
    or a bool where ``False`` means 0.0) on ``device``, and returns the BCE-with-
    logits loss against it.

    ``real_label`` / ``fake_label`` are kept as attributes but the loss itself
    does not use them.
    """

    def __init__(self, device):
        self.register_buffer('real_label', torch.tensor(1.0).to(device))
        self.register_buffer('fake_label', torch.tensor(0.0).to(device))
        self.loss = nn.BCEWithLogitsLoss()
        self.device = device

    def register_buffer(self, name, tensor):
        """Minimal stand-in for ``nn.Module.register_buffer``: a plain attribute set."""
        setattr(self, name, tensor)

    def get_target_tensor(self, prediction, target_value):
        """Return a tensor like ``prediction`` filled with ``target_value`` on ``self.device``."""
        return torch.full_like(prediction, target_value, device=self.device)

    def __call__(self, prediction, target_value):
        target = self.get_target_tensor(prediction, target_value)
        return self.loss(prediction, target)

def load_checkpoint(model, optimizer, filename):
    """Restore model and optimizer state from a checkpoint file.

    NOTE(review): an identical ``load_checkpoint`` is also defined in an earlier
    cell; this definition shadows it — keep only one.

    Args:
        model: Module to restore. If it is wrapped (e.g. in ``DataParallel``) the
            weights go into ``model.module``; otherwise into ``model`` itself.
        optimizer: Optimizer restored from the checkpoint's ``'optimizer_G'`` entry.
        filename: Path of a checkpoint dict with ``'state_dict'``,
            ``'optimizer_G'`` and ``'epoch'`` entries.

    Returns:
        The stored ``'epoch'`` value, or 0 when ``filename`` does not exist.
    """
    if not os.path.isfile(filename):
        return 0
    # Load onto the CPU so a GPU-saved checkpoint also opens on a CPU-only host;
    # load_state_dict then copies tensors onto each parameter's own device.
    checkpoint = torch.load(filename, map_location='cpu', weights_only=True)
    target = getattr(model, 'module', model)
    target.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_G'])
    return checkpoint['epoch']

def train_gan_epoch(generator, color_discriminator, attention_discriminator,
                   train_loader, gan_criterion_d, optimizer_G, optimizer_D_color,
                   optimizer_D_attention, device, scaler, scaler2, Grayscale_creator, epoch):
    """Run one epoch of mixed-precision GAN training for the colorization model.

    Per batch, both discriminators are updated together, then the generator is
    updated once.

    Discriminator step (scaler ``scaler``):
      * colour discriminator: ``real_color`` with target 0.9 vs. detached
        ``fake_color`` with target ``False`` (0.0);
      * attention discriminator: the same images multiplied by their saliency
        map (``real_saliency`` is the channel mean of the input batch);
      * the two losses are summed and back-propagated once, each discriminator's
        gradients are clipped to norm 1.0, then both optimizers step.

    Generator step (scaler ``scaler2``): adversarial losses from both
    discriminators with target 0.9, plus L1 colour loss (weight 1.0) and L1
    saliency loss (weight 0.5).

    Args:
        generator: Called as ``generator(gray_input, img_name)``; must return
            ``(fake_color, fake_saliency)``.
        color_discriminator: Discriminator for colour images.
        attention_discriminator: Discriminator for saliency-weighted images.
        train_loader: Iterable of ``(features, img_name)`` batches; ``features``
            serves as the colour target.
        gan_criterion_d: Callable ``(prediction, target_value) -> loss``.
        optimizer_G: Optimizer for the generator.
        optimizer_D_color: Optimizer for ``color_discriminator``.
        optimizer_D_attention: Optimizer for ``attention_discriminator``.
        device: Device each batch is moved to.
        scaler: Gradient scaler for the discriminator step.
        scaler2: Gradient scaler for the generator step.
        Grayscale_creator: Callable mapping ``features`` to the generator input.
        epoch: Epoch index; only used to decay ``noise_factor``.

    Returns:
        ``(avg_G_loss, avg_D_color_loss, avg_D_attention_loss)``, each averaged
        over the epoch weighted by batch size.
    """

    generator.train()
    color_discriminator.train()
    attention_discriminator.train()

    losses_G = AverageMeter()
    losses_D_color = AverageMeter()
    losses_D_attention = AverageMeter()

    with tqdm(train_loader, desc="Training GAN") as pbar:
        for features, img_name in pbar:
            batch_size = features.size(0)
            features = features.to(device)

            # Ground truth: the batch itself is the colour target and its
            # per-pixel channel mean is the saliency target.
            real_color = features[:, :, :]
            real_saliency = torch.mean(features, dim=1, keepdim=True)

            # Noise scale decays with the epoch and is floored at 0.01.
            noise_factor = max(0.01, 0.1 * (0.99 ** epoch))
            # NOTE(review): real_color_noisy is never used below — the
            # discriminators see the clean real_color. Confirm whether instance
            # noise was meant to be applied.
            real_color_noisy = real_color + noise_factor * torch.randn_like(real_color)

            # ---- Train discriminators ----
            optimizer_D_color.zero_grad()
            optimizer_D_attention.zero_grad()

            # Generate fake outputs.
            # NOTE(review): device_type is hard-coded to "cuda"; on a CPU-only run
            # autocast is presumably disabled with a warning — confirm if needed.
            with autocast(device_type="cuda"):
                gray_input = Grayscale_creator(features)
                # NOTE(review): gray_input_noisy is unused; the generator receives
                # the clean gray_input.
                gray_input_noisy = gray_input + noise_factor * torch.randn_like(gray_input)
                fake_color, fake_saliency = generator(gray_input, img_name)

                # Color discriminator: fakes are detached so this loss does not
                # reach the generator.
                pred_real_color = color_discriminator(real_color)
                loss_D_real_color = gan_criterion_d(pred_real_color, 0.9)

                pred_fake_color = color_discriminator(fake_color.detach())
                loss_D_fake_color = gan_criterion_d(pred_fake_color, False)

                loss_D_color = (loss_D_real_color + loss_D_fake_color) * 0.5

                # Attention discriminator (still inside the autocast block):
                # images weighted by their saliency map.
                real_weighted = real_color * real_saliency
                fake_weighted = fake_color.detach() * fake_saliency.detach()

                pred_real_attention = attention_discriminator(real_weighted)
                loss_D_real_attention = gan_criterion_d(pred_real_attention, 0.9)

                pred_fake_attention = attention_discriminator(fake_weighted)
                loss_D_fake_attention = gan_criterion_d(pred_fake_attention, False)

                loss_D_attention = (loss_D_real_attention + loss_D_fake_attention) * 0.5

            # One backward pass for both discriminator losses.
            scaler.scale(loss_D_attention + loss_D_color).backward()

            # Gradients are unscaled before clipping; each optimizer is unscaled
            # once, both step, and the shared scaler is updated once.
            scaler.unscale_(optimizer_D_color)
            torch.nn.utils.clip_grad_norm_(color_discriminator.parameters(), max_norm=1.0)
            scaler.unscale_(optimizer_D_attention)
            torch.nn.utils.clip_grad_norm_(attention_discriminator.parameters(), max_norm=1.0)
            scaler.step(optimizer_D_color)
            scaler.step(optimizer_D_attention)
            scaler.update()

            # ---- Train generator ----
            # Reuses fake_color / fake_saliency from the forward pass above; their
            # graph is intact because only detached copies were back-propagated.
            optimizer_G.zero_grad()

            with autocast(device_type="cuda"):
                # Color GAN loss (discriminators already updated this iteration)
                pred_fake_color = color_discriminator(fake_color)
                loss_G_color = gan_criterion_d(pred_fake_color, 0.9)

                # Attention GAN loss
                fake_weighted = fake_color * fake_saliency
                pred_fake_attention = attention_discriminator(fake_weighted)
                loss_G_attention = gan_criterion_d(pred_fake_attention, 0.9)

                # L1 reconstruction losses against the real targets
                loss_L1_color = F.l1_loss(fake_color, real_color)
                loss_L1_saliency = F.l1_loss(fake_saliency, real_saliency)

                # Combined generator loss
                loss_G = (loss_G_color + loss_G_attention +
                         1.0 * loss_L1_color + 0.5 * loss_L1_saliency)

            # NOTE(review): this backward also accumulates gradients in the
            # discriminator parameters; they are cleared by the zero_grad calls at
            # the start of the next iteration (assuming optimizer_G holds only
            # generator parameters).
            scaler2.scale(loss_G).backward()

            # With generator clipping disabled this explicit unscale_ is redundant:
            # scaler2.step would unscale the gradients itself.
            scaler2.unscale_(optimizer_G)
            # torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm=1.0)
            scaler2.step(optimizer_G)
            scaler2.update()

            # Update metrics
            losses_G.update(loss_G.item(), batch_size)
            losses_D_color.update(loss_D_color.item(), batch_size)
            losses_D_attention.update(loss_D_attention.item(), batch_size)

            pbar.set_postfix({
                'G_loss': f'{losses_G.avg:.4f}',
                'D_color_loss': f'{losses_D_color.avg:.4f}',
                'D_attention_loss': f'{losses_D_attention.avg:.4f}'
            })

    return losses_G.avg, losses_D_color.avg, losses_D_attention.avg

class VideoFeaturesDataset(Dataset):
    """Dataset over the named arrays stored in a ``.npz`` archive.

    ``self.data`` is the object returned by ``np.load`` (indexable by array name)
    and ``self.keys`` lists the array names in archive order. Items are returned
    as float32 tensors; 2-D arrays gain a leading channel axis, giving
    ``(1, H, W)``.

    Args:
        npz_path: Path to the ``.npz`` archive.
    """

    def __init__(self, npz_path):
        self.npz_path = npz_path
        # For .npz archives np.load ignores mmap_mode and returns an NpzFile that
        # reads each array only when it is indexed by name.
        self.data = np.load(self.npz_path, mmap_mode='r')
        self.keys = list(self.data.files)  # maps integer positions to array names

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, idx):
        """Return one array (or a stacked batch) as a float tensor.

        Args:
            idx: Integer position into ``self.keys`` (Python or NumPy integer),
                an array name, or a list/tuple of either; a list/tuple is returned
                stacked along a new leading axis.

        Returns:
            A float32 tensor, or ``None`` if the lookup or conversion fails (the
            error is printed).
        """
        try:
            if isinstance(idx, (tuple, list)):  # If batch, fetch each separately
                return torch.stack([self.__getitem__(i) for i in idx])
            # NpzFile is indexed by array name, so integer positions (as produced
            # by DataLoader samplers) must be translated through self.keys.
            key = self.keys[idx] if isinstance(idx, (int, np.integer)) else idx
            features = self.data[key]
            # Ensure features have a channel axis: (H, W) -> (1, H, W)
            if features.ndim == 2:
                features = features.reshape(1, *features.shape)
            return torch.from_numpy(features).float()

        except Exception as e:
            print(f"Error in __getitem__ for idx {idx}: {e}")
            return None  # Return None explicitly to trigger error

In [5]:
# Videos excluded from processing. Presumably these were already handled by an
# earlier (interrupted) run -- TODO confirm before reusing this list.
_DEFAULT_SKIP_VIDEOS = frozenset({
    'BenchPress_v_BenchPress_g01_c01',
    'ApplyLipstick_v_ApplyLipstick_g01_c01',
    'BabyCrawling_v_BabyCrawling_g05_c01',
    'CuttingInKitchen_v_CuttingInKitchen_g01_c01',
    'CleanAndJerk_v_CleanAndJerk_g10_c03',
    'Billiards_v_Billiards_g06_c02',
    'BlowDryHair_v_BlowDryHair_g22_c01',
    'BoxingSpeedBag_v_BoxingSpeedBag_g19_c02',
    'BreastStroke_v_BreastStroke_g17_c02',
})

# Normalization statistics; must match the Normalize() in main's input transform
# so generator outputs can be mapped back to displayable [0, 1] values.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)


def _save_image(image_chw, path):
    """Save a (3, H, W) tensor as an axis-free PNG, clipping values to [0, 1]."""
    image = np.transpose(image_chw.detach().cpu().numpy(), (1, 2, 0))
    image = np.clip(image, 0, 1)
    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.axis('off')
    fig.savefig(path)
    plt.close(fig)  # avoid accumulating open figures across many frames


def main(dataset_dir='/kaggle/input/gallery-keyframes/keyframes-dataset',
         output_dir='/kaggle/working/model_outputs',
         features_dir='/kaggle/input/feature-vectors-for-gallery-samples/model_outputs',
         checkpoint_path='/kaggle/input/epoch-26-colorization-gan/pytorch/default/2/best_model (3).pth',
         skip_videos=_DEFAULT_SKIP_VIDEOS):
    """Run the pretrained colorization generator over every keyframe video folder.

    For each ``<dataset_dir>/<type>/<video>`` directory (except names in
    ``skip_videos``), all frames are loaded as one batch, converted to grayscale,
    and passed through a freshly built generator. The generator is constructed with
    that video's ``grayscale_extracted_features.npz`` from ``features_dir`` and
    loaded from ``checkpoint_path``. One colorized PNG and one saliency PNG per
    frame are written under ``<output_dir>/<type>/<video>/Color`` and ``.../Saliency``.

    All arguments default to the original hard-coded Kaggle paths, so ``main()``
    behaves as before.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    learning_rate = 0.0002
    beta1 = 0.5
    beta2 = 0.999

    transform_image = transforms.Compose([
        transforms.Resize(128),
        transforms.CenterCrop(128),
        transforms.ToTensor(),
        transforms.Normalize(mean=list(_IMAGENET_MEAN), std=list(_IMAGENET_STD)),
    ])
    to_gray = transforms.Grayscale(num_output_channels=1)
    mean = torch.tensor(_IMAGENET_MEAN).view(3, 1, 1)
    std = torch.tensor(_IMAGENET_STD).view(3, 1, 1)

    for type_folder in os.listdir(dataset_dir):
        type_path = os.path.join(dataset_dir, type_folder)
        if not os.path.isdir(type_path):
            continue

        output_type_dir = os.path.join(output_dir, type_folder)
        os.makedirs(output_type_dir, exist_ok=True)

        for vid_folder in os.listdir(type_path):
            torch.cuda.empty_cache()  # release cached, unused GPU memory between videos
            if vid_folder in skip_videos:
                continue
            vid_path = os.path.join(type_path, vid_folder)
            if not os.path.isdir(vid_path):
                continue

            output_vid_dir = os.path.join(output_type_dir, vid_folder)
            os.makedirs(output_vid_dir, exist_ok=True)

            dataset = CustomDataSet(root=vid_path, transform=transform_image)
            num_frames = len(dataset)
            print(f"Total number of samples: {num_frames}")
            if num_frames == 0:
                # DataLoader rejects batch_size=0; don't let one empty folder abort the run.
                print(f"Skipping {vid_path}: no frames found")
                continue

            # All frames of a video are processed as a single batch (loaded once).
            test_loader = DataLoader(dataset, batch_size=num_frames, shuffle=False)
            frames, img_name = next(iter(test_loader))
            input_channels = frames[0].shape[0]
            print("input_channels = " + str(input_channels))

            l_d_path = os.path.join(features_dir, type_folder, vid_folder,
                                    'grayscale_extracted_features.npz')
            loaded_dataset = VideoFeaturesDataset(l_d_path)
            print(f"Total number of features loaded: {len(loaded_dataset)}")

            # DDP / ColorizationUNetWithAttentionDecoder / load_checkpoint are defined
            # elsewhere in the notebook.
            generator = DDP(ColorizationUNetWithAttentionDecoder(loaded_dataset, input_channels=1)).to(device)
            # load_checkpoint takes an optimizer argument (presumably to restore its
            # state), so one is still built even though no training happens here.
            optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate * 2,
                                     betas=(beta1, beta2), foreach=True, weight_decay=1e-6)
            load_checkpoint(generator, optimizer_G, checkpoint_path)

            generator.eval()
            with torch.no_grad():
                gray_input = to_gray(frames.to(device))
                fake_color, fake_saliency = generator(gray_input, img_name)
            fake_color = fake_color.to('cpu')
            fake_saliency = fake_saliency.to('cpu')

            output_col_dir = os.path.join(output_vid_dir, 'Color')
            os.makedirs(output_col_dir, exist_ok=True)
            for i in range(num_frames):
                _save_image(fake_color[i] * std + mean,
                            os.path.join(output_col_dir, f'test_video_colorization_{i}.png'))

            output_sal_dir = os.path.join(output_vid_dir, 'Saliency')
            os.makedirs(output_sal_dir, exist_ok=True)
            for i in range(num_frames):
                # Broadcast the saliency map to 3 channels before undoing normalization.
                saliency = fake_saliency[i].unsqueeze(0).expand(-1, 3, -1, -1) * std + mean
                # Saved once per frame (previously only the last frame was written).
                _save_image(saliency[0],
                            os.path.join(output_sal_dir, f'test_video_saliency{i}.png'))

            # Drop references so the next iteration's empty_cache() can free this model.
            del generator, optimizer_G, fake_color, fake_saliency, frames, gray_input


if __name__ == '__main__':
    # Run the inference pipeline when executed as a script / notebook cell.
    main()

Total number of samples: 22
input_channels = 3
Total number of features loaded: 22
Total number of samples: 12
input_channels = 3
Total number of features loaded: 12
Total number of samples: 30
input_channels = 3
Total number of features loaded: 30
Total number of samples: 24
input_channels = 3
Total number of features loaded: 24
Total number of samples: 2
input_channels = 3
Total number of features loaded: 2
Total number of samples: 6
input_channels = 3
Total number of features loaded: 6
Total number of samples: 46
input_channels = 3
Total number of features loaded: 46
Total number of samples: 20
input_channels = 3
Total number of features loaded: 20
Total number of samples: 8
input_channels = 3
Total number of features loaded: 8
Total number of samples: 10
input_channels = 3
Total number of features loaded: 10
Total number of samples: 40
input_channels = 3
Total number of features loaded: 40
Total number of samples: 40
input_channels = 3
Total number of features loaded: 40
Total numb

In [6]:
# s=4
# data = np.load('/kaggle/input/featurevgg-npz/extracted_features.npz', mmap_mode='r')
# keys = list(data.files)
# # key = keys[s]
# # features = data[key]
# # features = features.reshape(-1, features.shape[-2], 1, 1)
# # features.extend([torch.from_numpy(feat) for feat in features])
# x = data['64601.jpg']
# print(x.shape)
# x = x.reshape(7, -1, x.shape[-1])
# x.shape

In [7]:
# !rm -rf /kaggle/working/*

In [8]:
# import torch
# if torch.cuda.is_available():
#     print(f'CUDA is available. Number of GPUs: {torch.cuda.device_count()}')
# else:
#     print('CUDA is not available. Running on CPU.')

In [9]:
# param_size = 0
# for param in model.parameters():
#     param_size += param.nelement() * param.element_size()
# buffer_size = 0
# for buffer in model.buffers():
#     buffer_size += buffer.nelement() * buffer.element_size()

# size_all_mb = (param_size + buffer_size) / 2**20
# print('model size: {:.3f}MB'.format(size_all_mb))