In [29]:
import wandb
import torch
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import pandas as pd
from PIL import Image
import os
from torch.optim.lr_scheduler import StepLR
import random
from CombinationFunctions import ImageInputToDiT, NDiTModule, Decoder, TimeEmbedding, TextEmbedding
from tqdm import tqdm

if(torch.cuda.is_available()):
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print("Device: ", device)

Device:  cpu


In [None]:
wandb.login()

wandb.init(
    project="diffusion-transformer",  
    name="experiment-1",    
    id="4s8pcvm5",  
    resume="allow",
)


[34m[1mwandb[0m: Currently logged in as: [33misanand[0m ([33misanand-uc-san-diego[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [30]:
class ImageTextData(Dataset):
    def __init__(self, data, transform = None, rootDir = ""):
        super().__init__()
        self.data = data
        self.transform = transform
        self.rootDir = rootDir

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        row = self.data.iloc[index]

        image_path = os.path.join(self.rootDir, row['imagePath'])
        captions = [
            row['caption1'],
            row['caption2'],
            row['caption3'],
            row['caption4'],
            row['caption5']
        ]

        caption = random.choice(captions)

        image = Image.open(image_path)
        image = self.transform(image)
        return image, caption


data = pd.read_csv("dataset/COCO2017.csv")
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),                 
    transforms.Normalize([0.5]*3, [0.5]*3)])

idata = ImageTextData(data, transform)

image, caption = idata.__getitem__(2000)
image.shape, caption

(torch.Size([3, 512, 512]),
 'Two motorcycles parked next to each other on a lush green field.')

In [31]:
class FinalModel(nn.Module):
    def __init__(self, latentSize, latentChannel, embedDimension, patchSize, T, numHeads, blocks, dropout, beta_schedule = "squaredcos_cap_v2", modelName="mit-han-lab/dc-ae-f64c128-in-1.0-diffusers"):
        super().__init__()

        self.input = ImageInputToDiT(latentSize, latentChannel, embedDimension, patchSize, T, beta_schedule, modelName)

        self.timeEmbedding = TimeEmbedding(embedDimension)
        self.textEmbeding = TextEmbedding()

        self.ditBlocks = NDiTModule(blocks, embedDimension, numHeads, dropout)
        
        self.output = Decoder(embedDimension, latentSize, latentChannel, patchSize, T, beta_schedule, modelName)

    def forward(self, x, captions, t):

        batchSize, channels, height, width = x.shape
        noisedLatents, noise = self.input(x, t)

        timeEmbed = self.timeEmbedding(t)
        textembed = self.textEmbeding(captions)

        ditOutput = self.ditBlocks(noisedLatents, textembed, timeEmbed)
        
        predictedNoise = self.output(ditOutput)
        
        return predictedNoise, noise
    

fModel = FinalModel(8, 128, 768, 2, 1000, 12, 12, 0.2)



In [32]:
IMAGEHEIGHT = 512
IMAGEWIDTH = 512
EMBEDDINGDIM = 768
BATCHSIZE = 1
INCHANNELS = 3
LATENTSIZE = 8
LATENTCHANNEL = 128
PATCHSIZE = 2
T = 1000
DITBLOCK = 1#2
HEADS = 12
dropout = 0.2

In [33]:
epochs = 1000
data = pd.read_csv("dataset/COCO2017.csv")
transform = transforms.Compose([
    transforms.Resize((IMAGEHEIGHT, IMAGEWIDTH)),
    transforms.ToTensor(),                 
    transforms.Normalize([0.5]*3, [0.5]*3)])


model = FinalModel(latentSize=LATENTSIZE, latentChannel=LATENTCHANNEL, embedDimension=EMBEDDINGDIM, patchSize=PATCHSIZE,
                    T = T, numHeads=HEADS, blocks=DITBLOCK, dropout=dropout,
                    beta_schedule = "squaredcos_cap_v2", modelName = "mit-han-lab/dc-ae-f64c128-in-1.0-diffusers")


data = pd.read_csv("dataset/COCO2017.csv")
data = data.iloc[:1000]
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),                 
    transforms.Normalize([0.5]*3, [0.5]*3)])

torchDataset = ImageTextData(data, transform)
dataloader = DataLoader(torchDataset, batch_size=BATCHSIZE, shuffle = True, num_workers=0)
model = torch.nn.DataParallel(model)
model.to(device)




lossFn =  nn.MSELoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=2e-5, weight_decay=3e-2, eps=1e-10)
scheduler = StepLR(optimizer, step_size=10, gamma=0.5)

In [34]:

start_epoch = 0

checkpoint_dir = os.path.join("", "model")
checkpoint_path = os.path.join(checkpoint_dir, "dit.pt")

if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    print(f"Resuming from epoch {start_epoch}")
else:
    print("Loading pretrained model...")

Loading pretrained model...


In [None]:
for each_epoch in range(start_epoch, epochs):
    model.train()
    
    loop = tqdm(dataloader, f"{each_epoch}/{epochs}")
    ditloss = 0.0
    for X, captions in loop:
        t = torch.randint(0, T, (BATCHSIZE,), device=device).long()

        predictedNoise, noise = model(X, captions, t)
       
        # print(predictedNoise.shape, noise.shape)
    #     break
    # break
        loss = lossFn(predictedNoise, noise)
        ditloss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        loop.set_postfix({
            "DIT Loss": f"{ditloss}"
        })

    ditloss /= len(dataloader)   

    os.makedirs(checkpoint_path, exist_ok=True)
    torch.save({
        'epoch': each_epoch,
        'model_state_dict': model.module.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict()
    }, checkpoint_path)
    
    
    wandb.log({
        "Learning Rate": optimizer.param_groups[0]['lr'],
        "Decoder Loss": ditloss
    })
    scheduler.step()


0/1000:   0%|          | 1/1000 [00:10<2:47:09, 10.04s/it, DIT Loss=1.4896643161773682]

torch.Size([1, 128, 8, 8]) torch.Size([1, 128, 8, 8])


0/1000:   0%|          | 2/1000 [00:17<2:18:51,  8.35s/it, DIT Loss=2.995692014694214] 

torch.Size([1, 128, 8, 8]) torch.Size([1, 128, 8, 8])


0/1000:   0%|          | 3/1000 [00:23<2:05:54,  7.58s/it, DIT Loss=4.456230163574219]

torch.Size([1, 128, 8, 8]) torch.Size([1, 128, 8, 8])


0/1000:   0%|          | 3/1000 [00:27<2:32:02,  9.15s/it, DIT Loss=4.456230163574219]


KeyboardInterrupt: 