In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from lib.data import (
    
    ImageTokenDatasetClassLabel,
    ImageTokenDatasetSemanticLabel
)
from lib.models import (
    ConditionalTransformerDecoderConfig, ConditionalTransformerDecoder,
    VanillaTransformerDecoderConfig, VanillaTransformerDecoder
)
from lib.training import (
    ConditionalTransformerTrainer,
    UnconditionalTransformerTrainer
)

device = torch.device("cuda:0")

In [2]:
config = VanillaTransformerDecoderConfig(
    d_model=32,
    n_layers=1,
    n_heads=4
)

model = VanillaTransformerDecoder(config)

#params: 79296


In [None]:
trainer = UnconditionalTransformerTrainer(
    train_dataset=ImageTokenDataset(),
    batch_size=8,
    model=model,
    lr=2.25e-5,
    save_every_epoch=3,
    savepath="./logs/uncond_transformer",
    device=device
)

In [4]:
trainer.load_checkpoint("./logs/uncond_transformer/ckpt_e10")

...checkpoint [./logs/uncond_transformer/ckpt_e10] loaded!


In [5]:
trainer.train(15)

Train Epoch: 11: 100%|██████████| 2067/2067 [00:12<00:00, 165.91it/s, loss=6.05]
Train Epoch: 12: 100%|██████████| 2067/2067 [00:11<00:00, 177.16it/s, loss=6.05]
Train Epoch: 13: 100%|██████████| 2067/2067 [00:11<00:00, 178.52it/s, loss=6.05]
Train Epoch: 14: 100%|██████████| 2067/2067 [00:11<00:00, 179.05it/s, loss=6.04]
Train Epoch: 15: 100%|██████████| 2067/2067 [00:11<00:00, 177.04it/s, loss=6.04]


In [2]:
dataset = ImageTokenDatasetSemanticLabel()
dataset.n_classes

39

In [8]:
config = ConditionalTransformerDecoderConfig(
    d_model=16,
    n_layers=1,
    n_heads=2,
    class_prompt_length=5,
    n_classes=dataset.n_classes
)

model = ConditionalTransformerDecoder(config)

#params: 40192


In [9]:
trainer = ConditionalTransformerTrainer(
    train_dataset=dataset,
    batch_size=8,
    model=model,
    lr=2.25e-5,
    save_every_epoch=3,
    savepath="./logs/cond_transformer_class_label",
    device=device
)

In [11]:
trainer.load_checkpoint("./logs/cond_transformer_class_label/ckpt_e006.pt")

...checkpoint [./logs/cond_transformer_class_label/ckpt_e006.pt] loaded!


In [12]:
trainer.train(9)

Train Epoch: 7: 100%|██████████| 671/671 [00:04<00:00, 158.11it/s, loss=6.23]
Train Epoch: 8: 100%|██████████| 671/671 [00:04<00:00, 163.84it/s, loss=6.19]
Train Epoch: 9: 100%|██████████| 671/671 [00:04<00:00, 166.71it/s, loss=6.16]


In [15]:
torch.load("./logs/cond_transformer_class_label/ckpt_e009.pt")["loss_history"]

{'epoch_losses': [7.0436549684152165,
  6.865969944284498,
  6.661320198132988,
  6.497320607237951,
  6.379328360323046,
  6.294602692571376,
  6.233691889198633,
  6.18987272677585,
  6.157920587968897],
 'batch_losses': [7.1160125732421875,
  7.115261554718018,
  7.131227970123291,
  7.091777801513672,
  7.102847576141357,
  7.093673229217529,
  7.1206183433532715,
  7.11196231842041,
  7.113897323608398,
  7.124659061431885,
  7.095703125,
  7.128131866455078,
  7.115018844604492,
  7.087567329406738,
  7.112882614135742,
  7.09647274017334,
  7.105193138122559,
  7.09033203125,
  7.092896938323975,
  7.102370738983154,
  7.108880996704102,
  7.10540771484375,
  7.097070693969727,
  7.101273536682129,
  7.112944602966309,
  7.066572666168213,
  7.107778072357178,
  7.112666606903076,
  7.074819087982178,
  7.113194942474365,
  7.120478630065918,
  7.090183734893799,
  7.125960350036621,
  7.087242603302002,
  7.096743106842041,
  7.092584609985352,
  7.122498512268066,
  7.08397769