In [1]:
import torch
from torch import nn
from torchvision import transforms
from pathlib import Path

from dataset import *
from data_preprocessing import *
from utils import *
from diffusers.models import AutoencoderKL
# from diffusers import DDPMScheduler, DDPMPipeline
from unet import *
from attention import *
from time_embedding import *
from ddpm import *

In [2]:
# torch.cuda.empty_cache()

In [3]:
batch_size = 20
device = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
# unet = UNET(n_classes=2, in_channels=4, out_channels=4, channels=(128, 256, 512, 768), attn_channs=16).to(device)
unet = UNET(n_classes=2, in_channels=4, out_channels=4, channels=(32,64,128,256)).to(device)

In [5]:
ddpm = DDPM(0.0001, 0.02, 1000, device)

In [None]:
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(device).requires_grad_(False)


In [6]:
train_data_folder = "./data/256/"
train_latent_folder = Path("./data/latents256/")
test_data_folder = "./data/256test/"
test_latent_folder = Path("./data/latents256test/")


In [None]:
# train_data_folder = "./data/Linnaeus 5/256x256/train"
# train_latent_folder = Path("./data/Linnaeus 5/32x32/train")
# test_data_folder = "./data/Linnaeus 5/256x256/test"
# test_latent_folder = Path("./data/Linnaeus 5/32x32/test/")

In [None]:
def init_weights(m):
    if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, a=0.1)
unet.apply(init_weights);

In [None]:
transform = transforms.Compose([
                transforms.ToTensor(), 
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                transforms.Lambda(lambda x: x.float())
            ])

In [None]:
data_to_latents(model=vae, 
                  data_folder=train_data_folder, 
                  latent_folder=train_latent_folder, 
                  transform=transform, 
                  batch_size=5, 
                  shuffle=True
               )

In [None]:
data_to_latents(model=vae, 
                  data_folder=test_data_folder, 
                  latent_folder=test_latent_folder, 
                  transform=transform, 
                  batch_size=5, 
                  shuffle=False
               )

In [None]:
latent_transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[-1.2638,  0.7728,  0.6195, -0.9007], std=[6.1378, 7.1625, 4.4116, 5.5606]),
                # transforms.Lambda(lambda x: x.float())
            ])


In [7]:
train_latent_dl = latent_dataloader(latent_folder=train_latent_folder, 
                                    latent_transform=None, 
                                    batch_size=20, 
                                    shuffle=True
                                )

test_latent_dl = latent_dataloader(latent_folder=test_latent_folder, 
                                    latent_transform=None, 
                                    batch_size=5, 
                                    shuffle=False
                                )

In [None]:
for X, y in train_latent_dl:
    X, y = X.to(device), y.to(device)
    print(X.shape)
#     print(X)
#     print(y)
#     print(X.mean((0, 2, 3)))
#     print(X.std((0, 2, 3)))
    show_images_grid(X)
    xt, t, noise = ddpm.schedule(X)
    # pred = unet(xt, t, y)
    show_images_grid(xt)
    break

In [None]:
for X, y in train_latent_dl:
    X, y = X.to(device), y.to(device)
    print(X.shape)
    print(y.shape)

In [8]:
lr = 1e-3
epochs = 300
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(unet.parameters(), lr=lr)
schedular = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr, total_steps=epochs*len(train_latent_dl))

In [13]:
from IPython.display import display
from tqdm.notebook import trange

train_losses = []
test_losses = []

def train(model, epochs=None):
    model.to(device)
    progress_bar = trange(epochs, desc="Progress")
    display(progress_bar)
    for epoch in progress_bar:
        model.train()
        total_loss = 0
        for X, y in train_latent_dl:
            X, y = X.to(device), y.to(device)
            xt, t, noise = ddpm.schedule(X)
            pred = model(xt, t, y)
            loss = loss_fn(pred, noise)
            total_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        schedular.step()
        current_epoch = epoch+1

        model.eval()
        with torch.no_grad():
            total_val_loss = 0
            for val_X, val_y in test_latent_dl:
                val_X, val_y = val_X.to(device), val_y.to(device)
                val_xt, val_t, val_noise = ddpm.schedule(val_X)
                val_pred = model(val_xt, val_t, val_y)
                val_loss = loss_fn(val_pred, val_noise)
                test_losses.append(val_loss.item())

        total_val_loss += val_loss.item()
        total_loss /= len(train_latent_dl)
        train_losses.append(loss.item())
        total_val_loss /= len(test_latent_dl)
        if epoch % 10 == 0:
            torch.save(model.state_dict(), f"./Saved Models/kruto_latent/unet_{current_epoch}.pth")
        print(f"Epoch: {current_epoch} - Loss: {total_loss} | Val Loss: {total_val_loss}")

In [14]:
train(unet, epochs=epochs)

Progress:   0%|          | 0/300 [00:00<?, ?it/s]

<tqdm.notebook.tqdm_notebook at 0x250880f05b0>

Epoch: 1 - Loss: 1.0235024333000182 | Val Loss: 1.0164611637592316
Epoch: 2 - Loss: 1.0049712419509889 | Val Loss: 0.9978498071432114
Epoch: 3 - Loss: 0.9949985563755035 | Val Loss: 0.9791611582040787
Epoch: 4 - Loss: 0.9869266211986542 | Val Loss: 0.9802265167236328
Epoch: 5 - Loss: 0.9758458256721496 | Val Loss: 0.9707391485571861
Epoch: 6 - Loss: 0.958261251449585 | Val Loss: 0.9286753013730049
Epoch: 7 - Loss: 0.9457260012626648 | Val Loss: 0.9446269273757935
Epoch: 8 - Loss: 0.9338752746582031 | Val Loss: 0.9222942665219307
Epoch: 9 - Loss: 0.9185122489929199 | Val Loss: 0.9062851518392563
Epoch: 10 - Loss: 0.9091858208179474 | Val Loss: 0.9108386784791946
Epoch: 11 - Loss: 0.9051381945610046 | Val Loss: 0.8723679706454277
Epoch: 12 - Loss: 0.8675478458404541 | Val Loss: 0.8242794126272202
Epoch: 13 - Loss: 0.8497046828269958 | Val Loss: 0.8929253295063972
Epoch: 14 - Loss: 0.8529496252536773 | Val Loss: 0.821945920586586
Epoch: 15 - Loss: 0.8180499911308289 | Val Loss: 0.82930105

In [21]:
unet.state_dict()

OrderedDict([('timestep_embedding.timestep_mlp.0.weight',
              tensor([0.9062, 0.9120, 0.9042, 0.9126, 0.9070, 0.9157, 0.9408, 0.9843, 1.0243,
                      1.0226, 1.0432, 1.0249, 1.0156, 1.0156, 1.0210, 1.0174, 0.9083, 0.8988,
                      0.9021, 0.9169, 0.9076, 0.9151, 0.9298, 0.9941, 1.0554, 1.0462, 1.0358,
                      1.0267, 1.0253, 1.0213, 0.9976, 1.0329], device='cuda:0')),
             ('timestep_embedding.timestep_mlp.0.bias',
              tensor([-0.0801, -0.0693, -0.0807, -0.0726, -0.0684, -0.0654, -0.0311,  0.0178,
                       0.0550,  0.0303,  0.0503,  0.0277,  0.0179,  0.0253,  0.0256,  0.0247,
                      -0.0729, -0.0891, -0.0721, -0.0685, -0.0797, -0.0792, -0.0630,  0.0272,
                       0.0339,  0.0295,  0.0164,  0.0165,  0.0129,  0.0057, -0.0044,  0.0242],
                     device='cuda:0')),
             ('timestep_embedding.timestep_mlp.0.running_mean',
              tensor([-0.0327, -0.0220, -

In [20]:
torch.save(unet.state_dict(), "./Saved Models/kruto_latent/unet_300.pth")

In [15]:
def plot_losses(train_losses, val_losses):
    epochs = range(1, len(train_losses) + 1)

    plt.plot(epochs, train_losses, label='Training Loss')
    plt.plot(epochs, val_losses, label='Validation Loss')
    
    plt.title('Training and Validation Losses')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

In [None]:
plot_losses(train_losses, test_losses)

In [16]:
X, y = next(iter(test_latent_dl))
X = X.to(device)
y = y.to(device)