In [1]:
import torch as t
import torch.nn.functional as F
from tqdm import tqdm
from models.Autoencoder import Autoencoder
from torchvision import datasets, transforms

In [2]:
# Mini-batch size shared by the training DataLoader below.
bs = 512

# MNIST training split, stored under ./data/ (downloaded on first use).
# ToTensor converts each image to a tensor before it is returned.
train_ds = datasets.MNIST(
    root='./data/',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

# Reshuffle on every pass; drop_last=True discards the final partial batch so
# every batch holds exactly `bs` samples.
train_dl = t.utils.data.DataLoader(
    dataset=train_ds,
    batch_size=bs,
    shuffle=True,
    drop_last=True,
)

In [3]:
# Use the GPU when one is available, otherwise fall back to the CPU. The
# device is always built as a `t.device` so it has the same type on both paths.
device = t.device('cuda' if t.cuda.is_available() else 'cpu')

# `train_ds[0][0][None]` is the first training image with a leading batch axis
# added by `[None]`.
# NOTE(review): presumably Autoencoder uses this sample to work out its
# intermediate feature-map shapes -- confirm against models/Autoencoder.py.
model = Autoencoder(train_ds[0][0][None], in_c=1, enc_out_c=[32, 64, 64, 64],
                    enc_ks=[3, 3, 3, 3], enc_pads=[1, 1, 0, 1], enc_strides=[1, 2, 2, 1],
                    dec_out_c=[64, 64, 32, 1], dec_ks=[3, 3, 3, 3], dec_strides=[1, 2, 2, 1],
                    dec_pads=[1, 0, 1, 1], dec_op_pads=[0, 1, 1, 0], z_dim=2)

# `.to(device)` places the model correctly for both CPU and GPU; the previous
# `model.cuda(device)` could not work when the CPU fallback was selected.
# The call returns the module, so the cell still displays the model summary.
model.to(device)


Autoencoder(
  (enc_conv_layers): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): LeakyReLU(negative_slope=0.01)
      (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Dropout(p=0.25, inplace=False)
    )
    (1): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): LeakyReLU(negative_slope=0.01)
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Dropout(p=0.25, inplace=False)
    )
    (2): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2))
      (1): LeakyReLU(negative_slope=0.01)
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Dropout(p=0.25, inplace=False)
    )
    (3): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): LeakyReLU(negative_slope=

In [4]:
# Train the autoencoder to reconstruct its input for 20 epochs with Adam,
# starting at lr=5e-4 and dropping to lr=2e-4 from epoch 10 onwards.
optimizer = t.optim.Adam(model.parameters(), lr=5e-4, betas=(.9, .99), weight_decay=1e-2)
# Training mode: the model contains Dropout and BatchNorm layers, which behave
# differently in train and eval mode.
model.train()

for epoch in tqdm(range(20)):
    if epoch == 10:
        # Lower the learning rate on the existing optimizer. Constructing a
        # fresh Adam here would throw away the accumulated moment estimates
        # and step count, restarting the optimizer's state mid-training.
        for param_group in optimizer.param_groups:
            param_group['lr'] = 2e-4
    # Labels are ignored: the target of the reconstruction is the input itself.
    for i, (data, _) in enumerate(train_dl):
        data = data.to(device)
        optimizer.zero_grad()
        pred = model(data)
        # Mean-squared error between the reconstruction and the original batch.
        loss = F.mse_loss(pred, data)
        loss.backward()
        optimizer.step()
        # Log the current batch loss every 33rd batch.
        if i % 33 == 0:
            print(loss)


  0%|          | 0/20 [00:00<?, ?it/s]

tensor(0.3574, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.1034, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0816, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0712, device='cuda:0', grad_fn=<MseLossBackward>)


  5%|▌         | 1/20 [00:34<10:52, 34.34s/it]

tensor(0.0706, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0634, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0591, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0582, device='cuda:0', grad_fn=<MseLossBackward>)


 10%|█         | 2/20 [01:08<10:17, 34.32s/it]

tensor(0.0551, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0561, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0540, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0524, device='cuda:0', grad_fn=<MseLossBackward>)


 15%|█▌        | 3/20 [01:43<09:45, 34.43s/it]

tensor(0.0531, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0502, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0512, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0501, device='cuda:0', grad_fn=<MseLossBackward>)


 20%|██        | 4/20 [02:18<09:12, 34.55s/it]

tensor(0.0516, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0501, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0497, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0495, device='cuda:0', grad_fn=<MseLossBackward>)


 25%|██▌       | 5/20 [02:55<08:50, 35.36s/it]

tensor(0.0491, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0481, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0482, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0487, device='cuda:0', grad_fn=<MseLossBackward>)


 30%|███       | 6/20 [03:31<08:18, 35.63s/it]

tensor(0.0494, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0486, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0477, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0481, device='cuda:0', grad_fn=<MseLossBackward>)


 35%|███▌      | 7/20 [04:06<07:39, 35.31s/it]

tensor(0.0490, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0481, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0492, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0483, device='cuda:0', grad_fn=<MseLossBackward>)


 40%|████      | 8/20 [04:41<07:04, 35.36s/it]

tensor(0.0467, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0493, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0464, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0489, device='cuda:0', grad_fn=<MseLossBackward>)


 45%|████▌     | 9/20 [05:18<06:33, 35.75s/it]

tensor(0.0480, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0482, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0471, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0474, device='cuda:0', grad_fn=<MseLossBackward>)


 50%|█████     | 10/20 [05:53<05:57, 35.71s/it]

tensor(0.0472, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0461, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0462, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0468, device='cuda:0', grad_fn=<MseLossBackward>)


 55%|█████▌    | 11/20 [06:30<05:23, 35.95s/it]

tensor(0.0453, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0452, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0456, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0453, device='cuda:0', grad_fn=<MseLossBackward>)


 60%|██████    | 12/20 [07:06<04:48, 36.02s/it]

tensor(0.0458, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0451, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0457, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0457, device='cuda:0', grad_fn=<MseLossBackward>)


 65%|██████▌   | 13/20 [07:41<04:10, 35.77s/it]

tensor(0.0465, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0451, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0463, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0455, device='cuda:0', grad_fn=<MseLossBackward>)


 70%|███████   | 14/20 [08:17<03:33, 35.61s/it]

tensor(0.0462, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0468, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0458, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0468, device='cuda:0', grad_fn=<MseLossBackward>)


 75%|███████▌  | 15/20 [08:53<02:58, 35.79s/it]

tensor(0.0444, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0457, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0465, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0444, device='cuda:0', grad_fn=<MseLossBackward>)


 80%|████████  | 16/20 [09:28<02:22, 35.65s/it]

tensor(0.0452, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0462, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0459, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0453, device='cuda:0', grad_fn=<MseLossBackward>)


 85%|████████▌ | 17/20 [10:03<01:46, 35.51s/it]

tensor(0.0464, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0464, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0454, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0454, device='cuda:0', grad_fn=<MseLossBackward>)


 90%|█████████ | 18/20 [10:39<01:10, 35.42s/it]

tensor(0.0459, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0450, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0457, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0452, device='cuda:0', grad_fn=<MseLossBackward>)


 95%|█████████▌| 19/20 [11:14<00:35, 35.32s/it]

tensor(0.0448, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0470, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0458, device='cuda:0', grad_fn=<MseLossBackward>)
tensor(0.0459, device='cuda:0', grad_fn=<MseLossBackward>)


100%|██████████| 20/20 [11:48<00:00, 35.06s/it]


In [5]:
# Show the most recently computed training loss.
print(loss)

# Persist only the model's parameters and buffers (its state dict), not the
# whole module object.
weights = model.state_dict()
t.save(weights, 'models/state_dicts/03_01.pth')

tensor(0.0455, device='cuda:0', grad_fn=<MseLossBackward>)
