In [4]:
!pip install git+https://github.com/openai/CLIP.git

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to c:\users\sahas\appdata\local\temp\pip-req-build-krn6ztjx
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'
Collecting ftfy (from clip==1.0)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Downloading ftfy-6.3.1-py3-none-any.whl (44 kB)
Building wheels for collected packages: clip
  Building wheel for clip (setup.py): started
  Building wheel for clip (setup.py): finished with status 'done'
  Created wheel for clip: filename=clip-1.0-py3-none-any.whl size=1369570 sha256=7828b1d0e31496d666ebd4f5956d417d95a0825ac10fb5e20864f65dd82074c3
  Stored in directory: C:\Users\sahas\AppData\Local\Temp\pip-ephem-wheel-cache-ytoguyip\wheels\35\3e\df\3d24cbfb3b6a06f17a2bfd7d1138900d4365d9028aa8f6e92f
Successfully built clip
Installing collected pa

  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git 'C:\Users\sahas\AppData\Local\Temp\pip-req-build-krn6ztjx'


In [None]:
import os
import random
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from PIL import Image
import clip  # OpenAI CLIP library
from torch.utils.data import Subset

# Run on the GPU when PyTorch reports one as available; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Locations of the COCO train2017 data, given relative to the notebook's working directory.
# coco_image is the directory searched for "<12-digit image_id>.jpg" files by COCODataset;
# coco_annotation is the captions JSON whose "annotations" list the dataset indexes.
coco_image = 'Dataset/coco_images/train2017'
coco_annotation = 'Dataset/coco_annotations/captions_train2017.json'

In [3]:
class COCODataset(Dataset):
    """Caption-indexed dataset over a COCO-style captions annotation file.

    One item is produced per entry of the annotation file's ``"annotations"``
    list, so an image whose ``image_id`` appears in several annotations is
    returned once for each of them (each time with a different caption).

    Args:
        image_dir: Directory containing the images, named ``<image_id:012d>.jpg``.
        annotation_file: Path to a JSON file with a top-level ``"annotations"``
            list whose entries carry ``"caption"`` and ``"image_id"`` keys.
        transform: Optional callable applied to the loaded RGB ``PIL.Image``.
    """

    def __init__(self, image_dir, annotation_file, transform=None):
        self.image_dir = image_dir
        self.transform = transform

        # Read the annotations as UTF-8 explicitly: open() otherwise falls back to the
        # platform's locale encoding, which is not UTF-8 on a default Windows setup.
        with open(annotation_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
        self.annotations = data['annotations']

    def __len__(self):
        """Return the number of caption annotations (not the number of distinct images)."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Return ``(image, caption)`` for annotation ``idx``.

        Raises:
            FileNotFoundError: If the image file derived from ``image_id`` does not exist.
        """
        ann = self.annotations[idx]
        caption = ann['caption']
        image_id = ann['image_id']

        # Image files are named after the zero-padded 12-digit image_id.
        filename = f"{image_id:012d}.jpg"
        image_path = os.path.join(self.image_dir, filename)

        if not os.path.isfile(image_path):
            raise FileNotFoundError(f"Missing file: {image_path}")

        # Image.open() is lazy and keeps the file open; convert() forces the pixel data
        # to be read and returns an independent image, so the handle can be closed
        # deterministically here instead of whenever the object is garbage-collected.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)

        return image, caption

In [4]:
# Preprocessing applied to every training image:
#   1. resize to 64x64 (low resolution for demonstration),
#   2. convert to a tensor,
#   3. shift/scale each channel with mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5.
_resize_to_64 = transforms.Resize((64, 64))
_center_channels = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
transform = transforms.Compose([_resize_to_64, transforms.ToTensor(), _center_channels])

In [None]:
from torch.utils.data._utils.collate import default_collate

def collate_fn_filter_none(batch):
    """Collate a batch after discarding samples that are ``None``.

    Returns the result of ``default_collate`` on the remaining samples, or
    ``None`` when nothing is left (callers must handle that case themselves).
    """
    kept = list(filter(lambda sample: sample is not None, batch))
    return default_collate(kept) if kept else None

# Build the dataset over every caption annotation in the file.
dataset = COCODataset(image_dir=coco_image, annotation_file=coco_annotation, transform=transform)

# Shuffle the annotation indices with a dedicated, seeded RNG so the same subset is
# selected on every run (and the global `random` state is left untouched).
indices = list(range(len(dataset)))
random.Random(42).shuffle(indices)

# Use only 10k caption/image pairs for a quicker experiment. The dataset is indexed by
# annotation, so this is 10k captions rather than 10k distinct images.
subset_indices = indices[:10000]

small_dataset = Subset(dataset, subset_indices)

# On Windows, DataLoader workers are spawned as fresh processes that cannot import
# classes/functions defined inside a notebook, so load in the main process there.
# NOTE(review): assumes this runs as a notebook; a script with a __main__ guard could use workers.
loader_workers = 0 if os.name == "nt" else 4

dataloader = DataLoader(
    small_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=loader_workers,
    collate_fn=collate_fn_filter_none,
)


In [6]:
# Load the CLIP "ViT-B/32" model onto `device`. In the cells visible here only `model_clip`
# is used (its encode_text method, during training); `preprocess_clip` is not referenced.
model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)

100%|███████████████████████████████████████| 338M/338M [00:55<00:00, 6.39MiB/s]


In [7]:
class Generator(nn.Module):
    """Maps a noise vector plus a text embedding to a 3-channel 64x64 image.

    The concatenated input is projected by a fully connected layer to a
    ``(feature_dim * 8, 4, 4)`` feature map, which four stride-2 transposed
    convolutions then upsample 4 -> 8 -> 16 -> 32 -> 64. The final ``Tanh``
    bounds the output to ``[-1, 1]``.

    Args:
        noise_dim: Width of the noise vector.
        text_dim: Width of the text embedding.
        feature_dim: Base channel count; the widest layer has ``feature_dim * 8`` channels.
    """

    def __init__(self, noise_dim=100, text_dim=512, feature_dim=64):
        super().__init__()
        widest = feature_dim * 8
        flat_features = widest * 4 * 4

        # Projection of [noise | text] onto the flattened 4x4 seed feature map.
        self.fc = nn.Sequential(
            nn.Linear(noise_dim + text_dim, flat_features),
            nn.BatchNorm1d(flat_features),
            nn.ReLU(inplace=True),
        )

        # Three upsampling stages that halve the channel count each time
        # (4x4 -> 8x8 -> 16x16 -> 32x32) ...
        channel_plan = [widest, feature_dim * 4, feature_dim * 2, feature_dim]
        layers = []
        for c_in, c_out in zip(channel_plan[:-1], channel_plan[1:]):
            layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        # ... followed by the RGB output stage (32x32 -> 64x64).
        layers += [
            nn.ConvTranspose2d(feature_dim, 3, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ]
        self.deconv = nn.Sequential(*layers)

    def forward(self, noise, text_embed):
        """Generate images from ``noise`` and ``text_embed``, concatenated along dim 1."""
        conditioned = torch.cat((noise, text_embed), dim=1)
        seed_map = self.fc(conditioned)
        # Unflatten to (batch, channels, 4, 4) for the transposed convolutions.
        seed_map = seed_map.view(seed_map.size(0), -1, 4, 4)
        return self.deconv(seed_map)

In [8]:
class Discriminator(nn.Module):
    """Scores a 64x64 RGB image together with a text embedding.

    Four stride-2 convolutions downsample the image 64 -> 32 -> 16 -> 8 -> 4.
    The flattened features are concatenated with the text embedding and mapped
    by a linear layer to one unnormalised score (a logit) per sample.

    Args:
        text_dim: Width of the text embedding.
        feature_dim: Base channel count; the last conv layer has ``feature_dim * 8`` channels.
    """

    def __init__(self, text_dim=512, feature_dim=64):
        super().__init__()

        # First stage (64x64 -> 32x32) has no batch norm.
        stages = [
            nn.Conv2d(3, feature_dim, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Remaining stages double the channels while halving the resolution:
        # 32x32 -> 16x16 -> 8x8 -> 4x4.
        channels = feature_dim
        for multiplier in (2, 4, 8):
            wider = feature_dim * multiplier
            stages.extend([
                nn.Conv2d(channels, wider, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(wider),
                nn.LeakyReLU(0.2, inplace=True),
            ])
            channels = wider
        self.conv = nn.Sequential(*stages)

        # Final classifier over [flattened 4x4 features | text embedding].
        self.fc = nn.Linear(channels * 4 * 4 + text_dim, 1)

    def forward(self, img, text_embed):
        """Return one logit per sample, shape ``(batch, 1)``."""
        features = self.conv(img)
        flat = features.view(features.size(0), -1)
        joint = torch.cat((flat, text_embed), dim=1)
        return self.fc(joint)

In [9]:
# Widths of the generator's noise input and of the text embedding fed to both networks.
noise_dim, text_dim = 100, 512

# Instantiate both networks on the selected device.
G = Generator(noise_dim=noise_dim, text_dim=text_dim).to(device)
D = Discriminator(text_dim=text_dim).to(device)

# The discriminator ends in a plain Linear layer (raw logits), so use the
# logits-based binary cross-entropy.
criterion = nn.BCEWithLogitsLoss()

# Generator and discriminator share the same Adam hyperparameters.
adam_settings = {"lr": 0.0002, "betas": (0.5, 0.999)}
optimizerG = optim.Adam(G.parameters(), **adam_settings)
optimizerD = optim.Adam(D.parameters(), **adam_settings)

In [None]:
num_epochs = 50

for epoch in range(num_epochs):
    # Reset every epoch so the sample-grid step below can tell whether any batch
    # was actually processed (otherwise `fake_images` would be undefined or stale).
    fake_images = None

    for i, batch in enumerate(dataloader):
        # collate_fn_filter_none returns None when every sample of a batch was dropped;
        # unpacking that would raise a TypeError, so skip such batches.
        if batch is None:
            continue
        images, captions = batch
        batch_size = images.size(0)
        images = images.to(device)

        # Encode captions using CLIP. truncate=True shortens captions that exceed CLIP's
        # context length instead of raising and aborting the run.
        text_tokens = clip.tokenize(captions, truncate=True).to(device)
        with torch.no_grad():
            # CLIP may return half-precision features on GPU; cast to float32 so the
            # embeddings match the dtype of the noise, images and G/D parameters.
            text_embeddings = model_clip.encode_text(text_tokens).float()  # shape: (batch, 512)

        # Labels for real and fake images
        real_labels = torch.ones(batch_size, 1, device=device)
        fake_labels = torch.zeros(batch_size, 1, device=device)

        # ------------- Update Discriminator -------------
        D.zero_grad()
        # Real images loss
        outputs_real = D(images, text_embeddings)
        d_loss_real = criterion(outputs_real, real_labels)
        d_loss_real.backward()

        # Generate fake images; detach so the discriminator step does not
        # backpropagate into the generator.
        noise = torch.randn(batch_size, noise_dim, device=device)
        fake_images = G(noise, text_embeddings)
        outputs_fake = D(fake_images.detach(), text_embeddings)
        d_loss_fake = criterion(outputs_fake, fake_labels)
        d_loss_fake.backward()

        optimizerD.step()
        d_loss = d_loss_real + d_loss_fake

        # ------------- Update Generator -------------
        # The generator is rewarded when the (just-updated) discriminator labels
        # its images as real.
        G.zero_grad()
        outputs_fake_for_G = D(fake_images, text_embeddings)
        g_loss = criterion(outputs_fake_for_G, real_labels)
        g_loss.backward()
        optimizerG.step()

        if i % 100 == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i}/{len(dataloader)}], "
                  f"D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")

    # Save a grid of generated images (from the last processed batch) after each epoch,
    # provided at least one batch ran this epoch.
    if fake_images is not None:
        fake_images_sample = fake_images[:16].detach().cpu()
        grid = utils.make_grid(fake_images_sample, nrow=4, normalize=True)
        utils.save_image(grid, f"generated_epoch_{epoch+1}.png")

    # Model checkpoints are written every epoch regardless.
    torch.save(G.state_dict(), f"generator_epoch_{epoch+1}.pth")
    torch.save(D.state_dict(), f"discriminator_epoch_{epoch+1}.pth")