In [None]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from ZeroShotDataset import ZeroShotDataset
import pandas as pd
import pytorch_lightning as pl
from params import *
from transformers import CLIPProcessor, CLIPModel

In [None]:
class CLIPImageTextMerge(nn.Module):
    """Fuse text information into image tokens with one cross-attention layer.

    Image tokens act as attention queries and text tokens as keys/values, so
    the output keeps the image sequence length. The attention output is added
    back onto the image tokens (residual connection), layer-normalised and
    passed through a final linear projection.

    Args:
        d_model: Embedding width shared by the image and text tokens.
        n_heads: Number of attention heads (must divide ``d_model``).
    """

    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        # batch_first=True because forward() receives (batch, seq, d_model)
        # tensors. With the default (seq, batch, d_model) layout the differing
        # image/text sequence lengths would be read as mismatched batch sizes.
        self.mha = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.outfc = nn.Linear(d_model, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, image, text):
        """Merge text features into the image tokens.

        Args:
            image: Tensor of shape (batch_size, n_image_tokens, d_model),
                e.g. 50 image tokens.
            text: Tensor of shape (batch_size, n_text_tokens, d_model),
                e.g. 77 text tokens.

        Returns:
            Tensor with the same shape as ``image``.
        """
        # Cross-attention: image queries attend over the text keys/values.
        attended = self.mha(image, text, text)[0]
        # Residual connection to the *original* image tokens, then LayerNorm.
        fused = self.norm(image + attended)
        return self.outfc(fused)

In [None]:
class ImageAutoEncoder(pl.LightningModule):
    """Small convolutional autoencoder trained with an MSE reconstruction loss.

    The encoder applies two stride-2 3x3 convolutions (3 -> 256 -> 512
    channels), each halving height and width (rounding up). The decoder mirrors
    it with two stride-2 transposed convolutions (512 -> 256 -> 3 channels),
    each doubling height and width. When height and width are divisible by 4
    the reconstruction therefore has the same shape as the input.
    """

    def __init__(self):
        super().__init__()

        rgb_ch, hidden_ch, latent_ch = 3, 256, 512
        # Geometry shared by every (transposed) convolution in the model.
        conv_kwargs = dict(kernel_size=3, stride=2, padding=1)

        # (B, 3, H, W) -> (B, 512, ceil(H/4), ceil(W/4))
        self.encoder = nn.Sequential(
            nn.Conv2d(rgb_ch, hidden_ch, **conv_kwargs),
            nn.GELU(),
            nn.BatchNorm2d(hidden_ch),
            nn.Conv2d(hidden_ch, latent_ch, **conv_kwargs),
            nn.Sigmoid(),
            nn.BatchNorm2d(latent_ch),
        )

        # (B, 512, h, w) -> (B, 3, 4h, 4w); output_padding=1 makes each
        # transposed convolution exactly double the spatial size.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(latent_ch, hidden_ch, output_padding=1, **conv_kwargs),
            nn.GELU(),
            nn.BatchNorm2d(hidden_ch),
            nn.ConvTranspose2d(hidden_ch, rgb_ch, output_padding=1, **conv_kwargs),
        )

    def forward(self, x):
        """Encode and decode ``x``; return ``(reconstruction, latent)``."""
        latent = self.encoder(x)
        reconstruction = self.decoder(latent)
        return reconstruction, latent

    def _reconstruction_loss(self, batch):
        """MSE between the images of a batch and their reconstruction.

        ``batch`` must unpack into exactly three items; only the first (the
        image tensor) is used.
        """
        images, _, _ = batch
        reconstruction, _ = self(images)
        return F.mse_loss(reconstruction, images)

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the training loss."""
        loss = self._reconstruction_loss(batch)
        self.log('train_loss', loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute, log (as ``val_loss``) and return the validation loss."""
        loss = self._reconstruction_loss(batch)
        self.log('val_loss', loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Adam over all parameters with a fixed learning rate of 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)

In [None]:
# Smoke test: push one random 224x224 RGB image through an untrained model and
# verify the shapes, so an architecture mistake fails here rather than
# somewhere inside the training run.
x_test = torch.randn(1, 3, 224, 224)
model = ImageAutoEncoder()
x_hat, ex = model(x_test)

# The decoder must undo the encoder's 4x spatial down-sampling exactly.
assert x_hat.shape == x_test.shape, (
    f"reconstruction shape {tuple(x_hat.shape)} != input shape {tuple(x_test.shape)}"
)
# Two stride-2 convolutions: 224 -> 112 -> 56, with 512 latent channels.
assert ex.shape == (1, 512, 56, 56), f"unexpected latent shape {tuple(ex.shape)}"

In [None]:
# Display the reconstruction and latent shapes. For the 1x3x224x224 input the
# layer arithmetic gives (1, 3, 224, 224) and (1, 512, 56, 56).
x_hat.shape, ex.shape

In [None]:
# random_split is used below to carve a validation subset out of the dataset.
from torch.utils.data import random_split

# Table handed to ZeroShotDataset as `df`.
train_df = pd.read_csv("ProcessedDatasetStuff/csv/train.csv")
# CLIP ViT-B/32 processor; the dataset uses it as image processor and, through
# its `.tokenizer` attribute, as text tokenizer.
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

In [None]:
# Build the full dataset from the processed training images and masks.
# NOTE(review): the meaning of filter_unseen / filter_seen is defined inside
# ZeroShotDataset -- confirm this combination selects the intended classes.
complete_dataset = ZeroShotDataset(
    df = train_df,
    image_folder = 'ProcessedDatasetStuff/images/train/',
    mask_folder = 'ProcessedDatasetStuff/masks/train/',
    mask_size = 56,
    templates = TrainParams.TEMPLATES,
    unseen_classes = TrainParams.UNSEEN_CLASSES,
    image_processor = clip_processor,
    tokenizer = clip_processor.tokenizer,
    filter_unseen = False,
    filter_seen = True
)

# TRAIN_VAL_SPLIT is the fraction of samples used for training; the remainder
# becomes the validation set.
train_size = int(TrainParams.TRAIN_VAL_SPLIT * len(complete_dataset))
val_size = len(complete_dataset) - train_size
# Seed the split so the same samples land in train/val after every kernel
# restart; otherwise validation numbers are not comparable between runs.
train_dataset, val_dataset = random_split(
    complete_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)

In [None]:
# Train on the training split only. Wrapping complete_dataset here would also
# feed the validation samples to the optimiser and make val_loss meaningless.
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=complete_dataset.collate_fn)
# Validation order does not affect the averaged loss, so keep it deterministic.
valloader = torch.utils.data.DataLoader(val_dataset, batch_size=16, shuffle=False, collate_fn=complete_dataset.collate_fn)

In [None]:
# Start from a freshly initialised autoencoder; this rebinds `model`, replacing
# the instance created for the smoke test.
model = ImageAutoEncoder()
# Train for at most 10 epochs.
# NOTE(review): accelerator='gpu' presumably fails on a machine without a GPU;
# consider accelerator='auto' if CPU runs should be supported -- confirm.
trainer = pl.Trainer(max_epochs=10, accelerator='gpu')
trainer.fit(model, trainloader, valloader)

In [None]:
# Save only the model's state_dict (not the module object or trainer state).
# To reload: construct ImageAutoEncoder() and call load_state_dict on it.
torch.save(model.state_dict(), "CLIPImageAutoEncoder.pt")