In [None]:
!pip install torch torchvision Pillow transformers datasets requests fsspec git+https://github.com/openai/CLIP.git

In [None]:
import os
from getpass import getpass

import clip
import torch
import torch.nn as nn
import torch.optim as optim
from datasets import load_dataset
from huggingface_hub import login
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms


# Never hard-code the token in the notebook: it leaks with every share or commit.
# Read it from the HUGGINGFACE_TOKEN / HF_TOKEN environment variable, else prompt for it.
HUGGINGFACE_TOKEN = os.environ.get("HUGGINGFACE_TOKEN") or os.environ.get("HF_TOKEN")
if not HUGGINGFACE_TOKEN:
    HUGGINGFACE_TOKEN = getpass("Hugging Face token: ")
login(token=HUGGINGFACE_TOKEN)

In [None]:
# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# CLIP ViT-B/32: `model` provides encode_text for conditioning; clip.load also returns
# CLIP's own image preprocessing transform as `preprocess`.
model, preprocess = clip.load("ViT-B/32", device=device)

# Hugging Face COCO dataset, "2014" configuration, training split.
dataset = load_dataset("HuggingFaceM4/COCO", "2014", split="train")


In [None]:
# Real images must match the generator's output: 3 x 64 x 64, scaled to [-1, 1]
# (the generator ends in Tanh, and the sampling cell maps back with x * 0.5 + 0.5).
IMAGE_SIZE = 64
image_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


def _extract_captions(example):
    """Return the captions of one raw dataset row as a list of non-empty strings."""
    # NOTE(review): the caption schema of HuggingFaceM4/COCO is not visible in this notebook.
    # This prefers a "captions" field and falls back to "sentences" -> "raw"; check
    # `dataset.features` and simplify once confirmed.
    raw = example.get("captions")
    if raw is None:
        sentences = example.get("sentences")
        if isinstance(sentences, dict):
            raw = sentences.get("raw")
        elif isinstance(sentences, list):
            raw = [s.get("raw") for s in sentences if isinstance(s, dict)]
    if isinstance(raw, str):
        raw = [raw]
    return [caption for caption in (raw or []) if isinstance(caption, str) and caption]


def transform_dataset(batch):
    """Convert a batch of raw rows into ``{"image": [...], "captions": [...]}``.

    Args:
        batch: column-oriented batch (dict of column name -> list of values), as passed
            by ``Dataset.with_transform``.

    Returns:
        ``"image"``: one 3x64x64 tensor in [-1, 1] per row, or ``None`` when the row has
        no image (``collate_fn`` skips those rows). ``"captions"``: one list of caption
        strings per row.
    """
    rows = [dict(zip(batch.keys(), values)) for values in zip(*batch.values())]
    images, captions = [], []
    for row in rows:
        image = row.get("image")
        images.append(image_transform(image.convert("RGB")) if image is not None else None)
        captions.append(_extract_captions(row))
    return {"image": images, "captions": captions}


# Apply the transform lazily at access time rather than with `map`, which would write a
# decoded 3x64x64 float tensor per row into the Arrow cache. Rows without an image come back
# with image=None and are dropped by `collate_fn`, so no separate filter pass is needed
# (a `filter(lambda x: x is not None)` would keep everything anyway: rows are always dicts).
dataset = dataset.with_transform(transform_dataset)

In [None]:
def collate_fn(batch):
    """Stack a list of dataset rows into one training batch.

    Each row is expected to be a dict with an ``"image"`` (a tensor, or something
    ``torch.tensor`` can convert) and ``"captions"`` (a string or a list of strings).
    Rows with a missing or unconvertible image are skipped.

    Returns:
        ``{"image": stacked image tensor, "captions": CLIP token tensor}``, or ``{}`` when
        no usable row remains or the images cannot be stacked (the training loop skips
        ``{}``). Tensors stay on the CPU: creating CUDA tensors inside a collate function
        fails with ``DataLoader(num_workers > 0)``, so move them in the training loop.
    """
    images = []
    text_captions = []
    for item in batch:
        if item.get("image") is None or "captions" not in item:
            continue
        image = item["image"]
        if not isinstance(image, torch.Tensor):
            try:
                image = torch.tensor(image)
            except (TypeError, ValueError, RuntimeError) as e:
                # torch.tensor can also raise RuntimeError for objects it cannot convert.
                print(f"Error converting image to tensor: {e}")
                continue
        images.append(image)
        captions = item["captions"]
        # A bare string would otherwise be extended character by character.
        text_captions.extend([captions] if isinstance(captions, str) else captions)

    if not images:
        return {}

    try:
        image_batch = torch.stack(images)
    except (RuntimeError, TypeError):
        # torch.stack raises RuntimeError when image shapes differ.
        print("Error stacking images. Inspect batch content:")
        for image in images:
            print(f"Shape: {image.shape}, dtype: {image.dtype}")
        return {}

    # truncate=True: clip.tokenize otherwise raises on captions longer than CLIP's context.
    text_tokens = clip.tokenize(text_captions, truncate=True)
    return {"image": image_batch, "captions": text_tokens}

dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)


In [None]:
class Generator(nn.Module):
    """Text-conditioned MLP generator: (noise, text embedding) -> image in [-1, 1].

    Args:
        latent_dim: size of the noise vector ``z``.
        img_channels: number of channels of the generated image.
        text_embedding_dim: size of the conditioning text embedding.
        img_size: height and width of the square output image (default 64).
    """

    def __init__(self, latent_dim, img_channels, text_embedding_dim, img_size=64):
        super().__init__()
        self.text_embedding_dim = text_embedding_dim
        self.img_channels = img_channels
        self.img_size = img_size
        self.model = nn.Sequential(
            nn.Linear(latent_dim + text_embedding_dim, 256),
            nn.ReLU(),
            nn.Linear(256, img_channels * img_size * img_size),
            nn.Tanh(),  # bounds pixels to [-1, 1]
        )

    def forward(self, z, text_embedding):
        """Generate a batch of images.

        Args:
            z: noise of shape (batch, latent_dim).
            text_embedding: conditioning of shape (batch, text_embedding_dim).

        Returns:
            Tensor of shape (batch, img_channels, img_size, img_size).
        """
        combined_input = torch.cat((z, text_embedding), dim=1)
        # Use the configured channel count (previously hard-coded to 3).
        return self.model(combined_input).view(
            z.size(0), self.img_channels, self.img_size, self.img_size
        )

class Discriminator(nn.Module):
    """Unconditional MLP discriminator: image -> probability that it is real.

    Args:
        img_channels: number of channels of the input images.
        img_size: height and width of the square input images (default 64).
    """

    def __init__(self, img_channels, img_size=64):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(img_channels * img_size * img_size, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # probabilities, to pair with nn.BCELoss
        )

    def forward(self, img):
        """Score a batch of images (batch, C, H, W); returns shape (batch, 1)."""
        # flatten (unlike view) also accepts non-contiguous input such as permuted tensors.
        return self.model(torch.flatten(img, 1))

# Seed before building the models so weight initialisation and later torch RNG draws
# (noise, DataLoader shuffling) repeat on Restart & Run All. Some GPU kernels may still be
# nondeterministic.
SEED = 0
torch.manual_seed(SEED)

latent_dim = 100
img_channels = 3
lr = 0.0002
epochs = 5
# Width of the conditioning vector. It must match the width of the model.encode_text(...)
# output fed to the generator (512 for the CLIP ViT-B/32 loaded above; update if that changes).
text_embedding_dim = 512

generator = Generator(latent_dim, img_channels, text_embedding_dim).to(device)
discriminator = Discriminator(img_channels).to(device)
optimizer_g = optim.Adam(generator.parameters(), lr=lr)
optimizer_d = optim.Adam(discriminator.parameters(), lr=lr)
# The discriminator ends in Sigmoid (it outputs probabilities), so plain BCE is used.
criterion = nn.BCELoss()


In [None]:
# Adversarial training. The discriminator is called with images only, so the text embedding
# conditions only the generator.
# TODO: condition on each batch's real captions instead of this fixed prompt set once
# captions are paired 1:1 with images.
TRAIN_PROMPTS = ["a cat", "a dog", "a bird", "a car"]

# The prompts never change, so encode them once instead of on every batch.
with torch.no_grad():
    prompt_embeddings = model.encode_text(clip.tokenize(TRAIN_PROMPTS).to(device)).float()

generator.train()
discriminator.train()
for epoch in range(epochs):
    d_loss = g_loss = None
    for batch in dataloader:
        if not batch:
            continue
        real_images = batch["image"].to(device)
        batch_size = real_images.size(0)
        real_labels = torch.ones(batch_size, 1, device=device)
        fake_labels = torch.zeros(batch_size, 1, device=device)

        # Cycle through the prompts to get exactly one embedding per image for any batch
        # size. `prompts * (batch_size // 4)` came up short for a partial final batch and
        # made torch.cat in the generator fail. The same conditioning is used in both steps.
        prompt_index = torch.arange(batch_size, device=device) % len(TRAIN_PROMPTS)
        text_embedding = prompt_embeddings[prompt_index]

        # Discriminator step: real -> 1, fake -> 0 (fakes detached so G is not updated).
        optimizer_d.zero_grad()
        real_loss = criterion(discriminator(real_images), real_labels)
        z = torch.randn(batch_size, latent_dim, device=device)
        fake_images = generator(z, text_embedding)
        fake_loss = criterion(discriminator(fake_images.detach()), fake_labels)
        d_loss = real_loss + fake_loss
        d_loss.backward()
        optimizer_d.step()

        # Generator step: make the discriminator label fresh fakes as real.
        optimizer_g.zero_grad()
        z = torch.randn(batch_size, latent_dim, device=device)
        fake_images = generator(z, text_embedding)
        g_loss = criterion(discriminator(fake_images), real_labels)
        g_loss.backward()
        optimizer_g.step()

    if d_loss is None:
        print(f"Epoch [{epoch+1}/{epochs}]: no usable batches")
        continue
    print(f"Epoch [{epoch+1}/{epochs}], d_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}")




In [None]:
# Sample one image for a prompt. Generation needs no gradients. Image.show() would open an
# external viewer (nothing appears inline on a remote kernel), so the PIL image is left as
# the cell's last expression for Jupyter to render.
generator.eval()
prompt = ["a beautiful landscape"]
with torch.no_grad():
    z = torch.randn(1, latent_dim, device=device)
    text_tokens = clip.tokenize(prompt).to(device)
    text_embedding = model.encode_text(text_tokens).float()
    generated_tensor = generator(z, text_embedding)

# Map the Tanh output from [-1, 1] to [0, 255]; round (not truncate) and clamp before uint8.
generated_array = (
    ((generated_tensor.squeeze(0).permute(1, 2, 0) * 0.5 + 0.5) * 255.0)
    .round()
    .clamp(0, 255)
    .to(torch.uint8)
    .cpu()
    .numpy()
)
generated_image = Image.fromarray(generated_array)
generated_image