In [5]:
%pip install vit_pytorch

Collecting vit_pytorch
  Downloading vit_pytorch-1.6.9-py3-none-any.whl.metadata (65 kB)
     ---------------------------------------- 0.0/65.7 kB ? eta -:--:--
     ---------------------------------------- 65.7/65.7 kB 1.7 MB/s eta 0:00:00
Collecting einops>=0.7.0 (from vit_pytorch)
  Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)
Downloading vit_pytorch-1.6.9-py3-none-any.whl (119 kB)
   ---------------------------------------- 0.0/119.7 kB ? eta -:--:--
   ---------------------------------------- 119.7/119.7 kB 7.3 MB/s eta 0:00:00
Downloading einops-0.8.0-py3-none-any.whl (43 kB)
   ---------------------------------------- 0.0/43.2 kB ? eta -:--:--
   ---------------------------------------- 43.2/43.2 kB ? eta 0:00:00
Installing collected packages: einops, vit_pytorch
Successfully installed einops-0.8.0 vit_pytorch-1.6.9
Note: you may need to restart the kernel to use updated packages.


In [1]:
from dataloader import training_dataset, testing_dataset
import torch
from torch.utils.data import DataLoader
from vit_pytorch import ViT, SimpleViT
from torch import nn
from tqdm import tqdm
import os

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [3]:
training_data = DataLoader(training_dataset, batch_size=32, shuffle=True)
testing_data = DataLoader(testing_dataset, batch_size=32, shuffle=True)

In [4]:
vision_model = ViT(
    image_size=128 * 3,
    patch_size=32,
    num_classes=39,
    dim=512,
    depth=6,
    heads=16,
    mlp_dim=1024,
    dropout = 0.1,
    emb_dropout = 0.1
).to('cuda')

# vit_model = nn.Sequential(vision_model, nn.Sigmoid())

In [5]:
loss_fn = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(vision_model.parameters(), lr=1e-7)

In [6]:
state = {'loss': [], 'epochs': [], 'test_loss': []}
def train_loop(dataloader, testing_data, model, loss_fn, optimizer, epochs=10, save_path='./saves'):
    global state
    if not os.path.exists(save_path):
        os.mkdir(save_path)

    all_exps = [int(elm.replace('exp_', '')) if elm != '.ipynb_checkpoints' else -1 for elm in os.listdir(save_path)]

    current_num = max(all_exps) if (len(all_exps) > 0) else 1
    save_path = save_path + '/exp_' + str(current_num + 1)
    os.mkdir(save_path)

    model.to(device)

    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")
        model.train()
        loop = tqdm(dataloader, total=len(dataloader), leave=True)
        total_loss = 0

        for batch_idx, (data, targets) in enumerate(loop):
            data = data.to(device)
            targets = targets.to(device)

            # Forward pass
            scores = model(data)
            loss = loss_fn(scores, targets)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            # Update progress bar
            loop.set_postfix(loss=loss.item())

        testing_loss = 0
        for test in tqdm(testing_data):
            X, y = test
            with torch.no_grad():
                model.eval()

                score = model(X)
                testing_loss =+ loss_fn(score, y)

                model.train()

        print(f"Epoch {epoch + 1} average loss: {total_loss / len(dataloader)} with testing loss of {testing_loss / len(testing_data)}")

        state['loss'].append(total_loss / len(dataloader))
        state['epochs'].append(epoch + 1)
        state['test_loss'].append((testing_loss / len(testing_data)).cpu())

        # Save the model after each epoch
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': total_loss / len(dataloader),
            'testing_loss': (testing_loss / len(testing_data)).cpu()
        }, f"{save_path}/model_epoch_{epoch + 1}.pth")

    print("Training complete!")

In [7]:
train_loop(training_data, testing_data, vision_model, loss_fn, optimizer, epochs=100)

Epoch 1/100


100%|██████████| 64/64 [01:16<00:00,  1.20s/it, loss=0.697]
100%|██████████| 13/13 [00:13<00:00,  1.06s/it]


Epoch 1 average loss: 0.6990987602621317 with testing loss of 0.052406664937734604
Epoch 2/100


100%|██████████| 64/64 [01:07<00:00,  1.06s/it, loss=0.678]
100%|██████████| 13/13 [00:12<00:00,  1.07it/s]


Epoch 2 average loss: 0.6859181514009833 with testing loss of 0.05120355263352394
Epoch 3/100


100%|██████████| 64/64 [01:06<00:00,  1.04s/it, loss=0.668]
100%|██████████| 13/13 [00:12<00:00,  1.08it/s]


Epoch 3 average loss: 0.6730914637446404 with testing loss of 0.051010239869356155
Epoch 4/100


100%|██████████| 64/64 [01:09<00:00,  1.09s/it, loss=0.649]
100%|██████████| 13/13 [00:12<00:00,  1.02it/s]


Epoch 4 average loss: 0.6605385644361377 with testing loss of 0.05004071816802025
Epoch 5/100


100%|██████████| 64/64 [01:08<00:00,  1.06s/it, loss=0.636]
100%|██████████| 13/13 [00:12<00:00,  1.03it/s]


Epoch 5 average loss: 0.6482885256409645 with testing loss of 0.050910379737615585
Epoch 6/100


100%|██████████| 64/64 [01:13<00:00,  1.14s/it, loss=0.621]
100%|██████████| 13/13 [00:13<00:00,  1.03s/it]


Epoch 6 average loss: 0.6362931951880455 with testing loss of 0.047200679779052734
Epoch 7/100


100%|██████████| 64/64 [01:12<00:00,  1.14s/it, loss=0.624]
100%|██████████| 13/13 [00:12<00:00,  1.01it/s]


Epoch 7 average loss: 0.6246377946808934 with testing loss of 0.04694366082549095
Epoch 8/100


100%|██████████| 64/64 [01:14<00:00,  1.16s/it, loss=0.604]
100%|██████████| 13/13 [00:13<00:00,  1.03s/it]


Epoch 8 average loss: 0.6131544290110469 with testing loss of 0.046915192157030106
Epoch 9/100


100%|██████████| 64/64 [01:10<00:00,  1.10s/it, loss=0.59] 
100%|██████████| 13/13 [00:13<00:00,  1.07s/it]


Epoch 9 average loss: 0.6019465588033199 with testing loss of 0.04462288320064545
Epoch 10/100


100%|██████████| 64/64 [01:13<00:00,  1.15s/it, loss=0.59] 
100%|██████████| 13/13 [00:13<00:00,  1.02s/it]


Epoch 10 average loss: 0.5910469302907586 with testing loss of 0.043909091502428055
Epoch 11/100


100%|██████████| 64/64 [01:13<00:00,  1.14s/it, loss=0.573]
100%|██████████| 13/13 [00:13<00:00,  1.01s/it]


Epoch 11 average loss: 0.5803225692361593 with testing loss of 0.04383718967437744
Epoch 12/100


100%|██████████| 64/64 [01:10<00:00,  1.10s/it, loss=0.567]
100%|██████████| 13/13 [00:12<00:00,  1.01it/s]


Epoch 12 average loss: 0.5698895566165447 with testing loss of 0.0455230213701725
Epoch 13/100


100%|██████████| 64/64 [01:10<00:00,  1.11s/it, loss=0.554]
100%|██████████| 13/13 [00:12<00:00,  1.01it/s]


Epoch 13 average loss: 0.5596702257171273 with testing loss of 0.04145966097712517
Epoch 14/100


100%|██████████| 64/64 [01:09<00:00,  1.09s/it, loss=0.548]
100%|██████████| 13/13 [00:13<00:00,  1.03s/it]


Epoch 14 average loss: 0.5497210854664445 with testing loss of 0.041913777589797974
Epoch 15/100


100%|██████████| 64/64 [01:10<00:00,  1.10s/it, loss=0.531]
100%|██████████| 13/13 [00:12<00:00,  1.06it/s]


Epoch 15 average loss: 0.5399793870747089 with testing loss of 0.0395377017557621
Epoch 16/100


100%|██████████| 64/64 [01:06<00:00,  1.04s/it, loss=0.526]
100%|██████████| 13/13 [00:12<00:00,  1.05it/s]


Epoch 16 average loss: 0.5305149285122752 with testing loss of 0.04330543056130409
Epoch 17/100


100%|██████████| 64/64 [01:07<00:00,  1.05s/it, loss=0.503]
100%|██████████| 13/13 [00:12<00:00,  1.04it/s]


Epoch 17 average loss: 0.521234811283648 with testing loss of 0.04223942756652832
Epoch 18/100


100%|██████████| 64/64 [01:07<00:00,  1.06s/it, loss=0.5]  
100%|██████████| 13/13 [00:12<00:00,  1.04it/s]


Epoch 18 average loss: 0.5122687043622136 with testing loss of 0.037988629192113876
Epoch 19/100


100%|██████████| 64/64 [01:06<00:00,  1.04s/it, loss=0.488]
100%|██████████| 13/13 [00:12<00:00,  1.06it/s]


Epoch 19 average loss: 0.5035063680261374 with testing loss of 0.03735070303082466
Epoch 20/100


100%|██████████| 64/64 [01:07<00:00,  1.05s/it, loss=0.493]
100%|██████████| 13/13 [00:11<00:00,  1.09it/s]


Epoch 20 average loss: 0.4950480950064957 with testing loss of 0.041801899671554565
Epoch 21/100


100%|██████████| 64/64 [01:07<00:00,  1.05s/it, loss=0.49] 
100%|██████████| 13/13 [00:12<00:00,  1.07it/s]


Epoch 21 average loss: 0.48679743660613894 with testing loss of 0.0371873676776886
Epoch 22/100


100%|██████████| 64/64 [01:07<00:00,  1.06s/it, loss=0.465]
100%|██████████| 13/13 [00:12<00:00,  1.02it/s]


Epoch 22 average loss: 0.4787101433612406 with testing loss of 0.038378726691007614
Epoch 23/100


100%|██████████| 64/64 [01:07<00:00,  1.06s/it, loss=0.477]
100%|██████████| 13/13 [00:12<00:00,  1.08it/s]


Epoch 23 average loss: 0.47097973013296723 with testing loss of 0.03541867434978485
Epoch 24/100


100%|██████████| 64/64 [01:07<00:00,  1.05s/it, loss=0.449]
100%|██████████| 13/13 [00:12<00:00,  1.06it/s]


Epoch 24 average loss: 0.4633476766757667 with testing loss of 0.03754853829741478
Epoch 25/100


100%|██████████| 64/64 [01:07<00:00,  1.05s/it, loss=0.456]
100%|██████████| 13/13 [00:12<00:00,  1.05it/s]


Epoch 25 average loss: 0.45606185775250196 with testing loss of 0.035567205399274826
Epoch 26/100


100%|██████████| 64/64 [01:07<00:00,  1.05s/it, loss=0.451]
100%|██████████| 13/13 [00:12<00:00,  1.07it/s]


Epoch 26 average loss: 0.4489616462960839 with testing loss of 0.03699162229895592
Epoch 27/100


100%|██████████| 64/64 [01:07<00:00,  1.06s/it, loss=0.426]
100%|██████████| 13/13 [00:12<00:00,  1.06it/s]


Epoch 27 average loss: 0.442017185036093 with testing loss of 0.03374342992901802
Epoch 28/100


100%|██████████| 64/64 [01:07<00:00,  1.06s/it, loss=0.43] 
100%|██████████| 13/13 [00:12<00:00,  1.04it/s]


Epoch 28 average loss: 0.43538736971095204 with testing loss of 0.03238853067159653
Epoch 29/100


100%|██████████| 64/64 [01:07<00:00,  1.05s/it, loss=0.422]
100%|██████████| 13/13 [00:12<00:00,  1.05it/s]


Epoch 29 average loss: 0.4289335482753813 with testing loss of 0.03286430984735489
Epoch 30/100


100%|██████████| 64/64 [01:07<00:00,  1.06s/it, loss=0.425]
100%|██████████| 13/13 [00:12<00:00,  1.05it/s]


Epoch 30 average loss: 0.4227253133431077 with testing loss of 0.03333330526947975
Epoch 31/100


100%|██████████| 64/64 [01:07<00:00,  1.05s/it, loss=0.406]
100%|██████████| 13/13 [00:12<00:00,  1.04it/s]


Epoch 31 average loss: 0.4166466323658824 with testing loss of 0.03083968348801136
Epoch 32/100


100%|██████████| 64/64 [01:07<00:00,  1.06s/it, loss=0.413]
100%|██████████| 13/13 [00:13<00:00,  1.00s/it]


Epoch 32 average loss: 0.4108598087914288 with testing loss of 0.029493045061826706
Epoch 33/100


100%|██████████| 64/64 [01:07<00:00,  1.06s/it, loss=0.397]
100%|██████████| 13/13 [00:12<00:00,  1.07it/s]


Epoch 33 average loss: 0.40518762404099107 with testing loss of 0.030325831845402718
Epoch 34/100


100%|██████████| 64/64 [01:09<00:00,  1.08s/it, loss=0.397]
100%|██████████| 13/13 [00:12<00:00,  1.07it/s]


Epoch 34 average loss: 0.39976340578868985 with testing loss of 0.029376138001680374
Epoch 35/100


100%|██████████| 64/64 [01:06<00:00,  1.04s/it, loss=0.407]
100%|██████████| 13/13 [00:12<00:00,  1.08it/s]


Epoch 35 average loss: 0.3945610555820167 with testing loss of 0.03291512280702591
Epoch 36/100


100%|██████████| 64/64 [01:07<00:00,  1.05s/it, loss=0.401]
100%|██████████| 13/13 [00:12<00:00,  1.06it/s]


Epoch 36 average loss: 0.38948369259014726 with testing loss of 0.02560555376112461
Epoch 37/100


100%|██████████| 64/64 [01:06<00:00,  1.05s/it, loss=0.371]
100%|██████████| 13/13 [00:12<00:00,  1.04it/s]


Epoch 37 average loss: 0.3844992551021278 with testing loss of 0.03201055899262428
Epoch 38/100


100%|██████████| 64/64 [01:06<00:00,  1.04s/it, loss=0.369]
100%|██████████| 13/13 [00:12<00:00,  1.07it/s]


Epoch 38 average loss: 0.3797812694683671 with testing loss of 0.03250329568982124
Epoch 39/100


100%|██████████| 64/64 [01:06<00:00,  1.03s/it, loss=0.366]
100%|██████████| 13/13 [00:12<00:00,  1.07it/s]


Epoch 39 average loss: 0.3752288958057761 with testing loss of 0.027523266151547432
Epoch 40/100


100%|██████████| 64/64 [01:06<00:00,  1.04s/it, loss=0.374]
100%|██████████| 13/13 [00:12<00:00,  1.07it/s]


Epoch 40 average loss: 0.37087072897702456 with testing loss of 0.028637852519750595
Epoch 41/100


100%|██████████| 64/64 [01:06<00:00,  1.04s/it, loss=0.398]
100%|██████████| 13/13 [00:12<00:00,  1.05it/s]


Epoch 41 average loss: 0.36672773072496057 with testing loss of 0.024083290249109268
Epoch 42/100


100%|██████████| 64/64 [01:08<00:00,  1.08s/it, loss=0.369]
100%|██████████| 13/13 [00:12<00:00,  1.04it/s]


Epoch 42 average loss: 0.36255497112870216 with testing loss of 0.027656463906168938
Epoch 43/100


100%|██████████| 64/64 [01:07<00:00,  1.05s/it, loss=0.336]
100%|██████████| 13/13 [00:12<00:00,  1.04it/s]


Epoch 43 average loss: 0.35851036151871085 with testing loss of 0.027944404631853104
Epoch 44/100


100%|██████████| 64/64 [01:07<00:00,  1.05s/it, loss=0.368]
100%|██████████| 13/13 [00:12<00:00,  1.00it/s]


Epoch 44 average loss: 0.3548305034637451 with testing loss of 0.026396827772259712
Epoch 45/100


100%|██████████| 64/64 [01:06<00:00,  1.04s/it, loss=0.362]
100%|██████████| 13/13 [00:12<00:00,  1.06it/s]


Epoch 45 average loss: 0.3511568955145776 with testing loss of 0.026970919221639633
Epoch 46/100


100%|██████████| 64/64 [01:07<00:00,  1.05s/it, loss=0.341]
100%|██████████| 13/13 [00:12<00:00,  1.04it/s]


Epoch 46 average loss: 0.34756343672052026 with testing loss of 0.02858647145330906
Epoch 47/100


100%|██████████| 64/64 [01:09<00:00,  1.08s/it, loss=0.339]
100%|██████████| 13/13 [00:12<00:00,  1.00it/s]


Epoch 47 average loss: 0.3441659929230809 with testing loss of 0.02397289127111435
Epoch 48/100


100%|██████████| 64/64 [01:09<00:00,  1.09s/it, loss=0.318]
100%|██████████| 13/13 [00:12<00:00,  1.02it/s]


Epoch 48 average loss: 0.34081867430359125 with testing loss of 0.025392809882760048
Epoch 49/100


100%|██████████| 64/64 [01:10<00:00,  1.10s/it, loss=0.358]
100%|██████████| 13/13 [00:12<00:00,  1.07it/s]


Epoch 49 average loss: 0.3378031440079212 with testing loss of 0.02538568712770939
Epoch 50/100


100%|██████████| 64/64 [01:07<00:00,  1.05s/it, loss=0.341]
100%|██████████| 13/13 [00:13<00:00,  1.05s/it]


Epoch 50 average loss: 0.33470576256513596 with testing loss of 0.02538434974849224
Epoch 51/100


100%|██████████| 64/64 [01:11<00:00,  1.12s/it, loss=0.32] 
100%|██████████| 13/13 [00:12<00:00,  1.07it/s]


Epoch 51 average loss: 0.3317028833553195 with testing loss of 0.025348668918013573
Epoch 52/100


100%|██████████| 64/64 [01:06<00:00,  1.04s/it, loss=0.361]
100%|██████████| 13/13 [00:11<00:00,  1.09it/s]


Epoch 52 average loss: 0.32902099238708615 with testing loss of 0.026867913082242012
Epoch 53/100


100%|██████████| 64/64 [01:06<00:00,  1.03s/it, loss=0.331]
100%|██████████| 13/13 [00:12<00:00,  1.05it/s]


Epoch 53 average loss: 0.326192082837224 with testing loss of 0.022885456681251526
Epoch 54/100


100%|██████████| 64/64 [01:11<00:00,  1.12s/it, loss=0.331]
100%|██████████| 13/13 [00:13<00:00,  1.03s/it]


Epoch 54 average loss: 0.3235684260725975 with testing loss of 0.02813788689672947
Epoch 55/100


100%|██████████| 64/64 [01:13<00:00,  1.14s/it, loss=0.333]
100%|██████████| 13/13 [00:13<00:00,  1.04s/it]


Epoch 55 average loss: 0.3210467058233917 with testing loss of 0.024693498387932777
Epoch 56/100


100%|██████████| 64/64 [01:11<00:00,  1.12s/it, loss=0.324]
100%|██████████| 13/13 [00:13<00:00,  1.03s/it]


Epoch 56 average loss: 0.3185771689750254 with testing loss of 0.024923188611865044
Epoch 57/100


100%|██████████| 64/64 [01:11<00:00,  1.12s/it, loss=0.311]
100%|██████████| 13/13 [00:12<00:00,  1.03it/s]


Epoch 57 average loss: 0.3161795726045966 with testing loss of 0.024650687351822853
Epoch 58/100


100%|██████████| 64/64 [01:10<00:00,  1.10s/it, loss=0.305]
100%|██████████| 13/13 [00:13<00:00,  1.03s/it]


Epoch 58 average loss: 0.31389343412593007 with testing loss of 0.0227966271340847
Epoch 59/100


100%|██████████| 64/64 [01:07<00:00,  1.05s/it, loss=0.294]
100%|██████████| 13/13 [00:12<00:00,  1.03it/s]


Epoch 59 average loss: 0.3116674078628421 with testing loss of 0.022457556799054146
Epoch 60/100


100%|██████████| 64/64 [01:06<00:00,  1.04s/it, loss=0.31] 
100%|██████████| 13/13 [00:11<00:00,  1.09it/s]


Epoch 60 average loss: 0.3096131728962064 with testing loss of 0.02060733735561371
Epoch 61/100


100%|██████████| 64/64 [01:06<00:00,  1.04s/it, loss=0.314]
100%|██████████| 13/13 [00:12<00:00,  1.07it/s]


Epoch 61 average loss: 0.30759339313954115 with testing loss of 0.022606998682022095
Epoch 62/100


100%|██████████| 64/64 [01:06<00:00,  1.03s/it, loss=0.334]
100%|██████████| 13/13 [00:12<00:00,  1.01it/s]


Epoch 62 average loss: 0.30569906951859593 with testing loss of 0.02816016785800457
Epoch 63/100


100%|██████████| 64/64 [01:08<00:00,  1.07s/it, loss=0.324]
100%|██████████| 13/13 [00:12<00:00,  1.02it/s]


Epoch 63 average loss: 0.30377297196537256 with testing loss of 0.02405899204313755
Epoch 64/100


100%|██████████| 64/64 [01:08<00:00,  1.07s/it, loss=0.329]
100%|██████████| 13/13 [00:12<00:00,  1.05it/s]


Epoch 64 average loss: 0.30196488089859486 with testing loss of 0.027864981442689896
Epoch 65/100


100%|██████████| 64/64 [01:09<00:00,  1.09s/it, loss=0.313]
100%|██████████| 13/13 [00:12<00:00,  1.01it/s]


Epoch 65 average loss: 0.30014851642772555 with testing loss of 0.022232677787542343
Epoch 66/100


100%|██████████| 64/64 [01:09<00:00,  1.09s/it, loss=0.297]
100%|██████████| 13/13 [00:13<00:00,  1.02s/it]


Epoch 66 average loss: 0.2983942381106317 with testing loss of 0.028927553445100784
Epoch 67/100


100%|██████████| 64/64 [01:10<00:00,  1.10s/it, loss=0.303]
100%|██████████| 13/13 [00:12<00:00,  1.00it/s]


Epoch 67 average loss: 0.29677654756233096 with testing loss of 0.02209579199552536
Epoch 68/100


100%|██████████| 64/64 [01:11<00:00,  1.11s/it, loss=0.273]
100%|██████████| 13/13 [00:13<00:00,  1.02s/it]


Epoch 68 average loss: 0.2950883712619543 with testing loss of 0.025059126317501068
Epoch 69/100


100%|██████████| 64/64 [01:10<00:00,  1.10s/it, loss=0.282]
100%|██████████| 13/13 [00:13<00:00,  1.02s/it]


Epoch 69 average loss: 0.2935898401774466 with testing loss of 0.027731137350201607
Epoch 70/100


100%|██████████| 64/64 [01:13<00:00,  1.15s/it, loss=0.288]
100%|██████████| 13/13 [00:13<00:00,  1.04s/it]


Epoch 70 average loss: 0.2921348875388503 with testing loss of 0.024878578260540962
Epoch 71/100


100%|██████████| 64/64 [01:11<00:00,  1.11s/it, loss=0.298]
100%|██████████| 13/13 [00:12<00:00,  1.02it/s]


Epoch 71 average loss: 0.29074303759261966 with testing loss of 0.021474778652191162
Epoch 72/100


100%|██████████| 64/64 [01:11<00:00,  1.11s/it, loss=0.303]
100%|██████████| 13/13 [00:13<00:00,  1.02s/it]


Epoch 72 average loss: 0.28938569128513336 with testing loss of 0.02467629685997963
Epoch 73/100


100%|██████████| 64/64 [01:14<00:00,  1.16s/it, loss=0.32] 
100%|██████████| 13/13 [00:12<00:00,  1.01it/s]


Epoch 73 average loss: 0.28811280708760023 with testing loss of 0.02628917247056961
Epoch 74/100


100%|██████████| 64/64 [01:12<00:00,  1.13s/it, loss=0.274]
100%|██████████| 13/13 [00:12<00:00,  1.00it/s]


Epoch 74 average loss: 0.28667292883619666 with testing loss of 0.02830479107797146
Epoch 75/100


100%|██████████| 64/64 [01:12<00:00,  1.13s/it, loss=0.266]
100%|██████████| 13/13 [00:13<00:00,  1.01s/it]


Epoch 75 average loss: 0.28540657786652446 with testing loss of 0.02218768373131752
Epoch 76/100


100%|██████████| 64/64 [01:09<00:00,  1.08s/it, loss=0.275]
100%|██████████| 13/13 [00:12<00:00,  1.03it/s]


Epoch 76 average loss: 0.2842394900508225 with testing loss of 0.020123744383454323
Epoch 77/100


100%|██████████| 64/64 [01:10<00:00,  1.10s/it, loss=0.263]
100%|██████████| 13/13 [00:12<00:00,  1.01it/s]


Epoch 77 average loss: 0.2830415512435138 with testing loss of 0.018923351541161537
Epoch 78/100


100%|██████████| 64/64 [01:10<00:00,  1.10s/it, loss=0.253]
100%|██████████| 13/13 [00:13<00:00,  1.01s/it]


Epoch 78 average loss: 0.2818850197363645 with testing loss of 0.018572138622403145
Epoch 79/100


100%|██████████| 64/64 [01:11<00:00,  1.12s/it, loss=0.268]
100%|██████████| 13/13 [00:13<00:00,  1.05s/it]


Epoch 79 average loss: 0.2808552065398544 with testing loss of 0.021513137966394424
Epoch 80/100


100%|██████████| 64/64 [01:09<00:00,  1.09s/it, loss=0.265]
100%|██████████| 13/13 [00:12<00:00,  1.04it/s]


Epoch 80 average loss: 0.27979950816370547 with testing loss of 0.019360052421689034
Epoch 81/100


100%|██████████| 64/64 [01:08<00:00,  1.08s/it, loss=0.277]
100%|██████████| 13/13 [00:12<00:00,  1.02it/s]


Epoch 81 average loss: 0.27883631899021566 with testing loss of 0.024522261694073677
Epoch 82/100


100%|██████████| 64/64 [01:08<00:00,  1.08s/it, loss=0.251]
100%|██████████| 13/13 [00:12<00:00,  1.02it/s]


Epoch 82 average loss: 0.27776746288873255 with testing loss of 0.02044270932674408
Epoch 83/100


100%|██████████| 64/64 [01:09<00:00,  1.09s/it, loss=0.276]
100%|██████████| 13/13 [00:13<00:00,  1.00s/it]


Epoch 83 average loss: 0.27690772456116974 with testing loss of 0.024500755593180656
Epoch 84/100


100%|██████████| 64/64 [01:05<00:00,  1.03s/it, loss=0.316]
100%|██████████| 13/13 [00:12<00:00,  1.02it/s]


Epoch 84 average loss: 0.27613051887601614 with testing loss of 0.025309929624199867
Epoch 85/100


100%|██████████| 64/64 [01:05<00:00,  1.02s/it, loss=0.264]
100%|██████████| 13/13 [00:12<00:00,  1.06it/s]


Epoch 85 average loss: 0.27507221954874694 with testing loss of 0.021433129906654358
Epoch 86/100


100%|██████████| 64/64 [01:05<00:00,  1.03s/it, loss=0.278]
100%|██████████| 13/13 [00:11<00:00,  1.09it/s]


Epoch 86 average loss: 0.2742618788033724 with testing loss of 0.02387085184454918
Epoch 87/100


100%|██████████| 64/64 [01:06<00:00,  1.03s/it, loss=0.249]
100%|██████████| 13/13 [00:12<00:00,  1.06it/s]


Epoch 87 average loss: 0.27333381143398583 with testing loss of 0.02089138887822628
Epoch 88/100


100%|██████████| 64/64 [01:05<00:00,  1.03s/it, loss=0.278]
100%|██████████| 13/13 [00:12<00:00,  1.07it/s]


Epoch 88 average loss: 0.2726305259857327 with testing loss of 0.020257649943232536
Epoch 89/100


100%|██████████| 64/64 [01:06<00:00,  1.03s/it, loss=0.28] 
100%|██████████| 13/13 [00:13<00:00,  1.02s/it]


Epoch 89 average loss: 0.27186522516421974 with testing loss of 0.02514961175620556
Epoch 90/100


100%|██████████| 64/64 [01:07<00:00,  1.05s/it, loss=0.281]
100%|██████████| 13/13 [00:12<00:00,  1.06it/s]


Epoch 90 average loss: 0.2711120645981282 with testing loss of 0.017853524535894394
Epoch 91/100


100%|██████████| 64/64 [01:08<00:00,  1.08s/it, loss=0.264]
100%|██████████| 13/13 [00:12<00:00,  1.04it/s]


Epoch 91 average loss: 0.27032613893970847 with testing loss of 0.02254546247422695
Epoch 92/100


100%|██████████| 64/64 [01:07<00:00,  1.06s/it, loss=0.255]
100%|██████████| 13/13 [00:11<00:00,  1.09it/s]


Epoch 92 average loss: 0.2695917545352131 with testing loss of 0.016232723370194435
Epoch 93/100


100%|██████████| 64/64 [01:06<00:00,  1.04s/it, loss=0.285]
100%|██████████| 13/13 [00:12<00:00,  1.02it/s]


Epoch 93 average loss: 0.269007851369679 with testing loss of 0.020157385617494583
Epoch 94/100


100%|██████████| 64/64 [01:09<00:00,  1.09s/it, loss=0.261]
100%|██████████| 13/13 [00:12<00:00,  1.01it/s]


Epoch 94 average loss: 0.2682656121905893 with testing loss of 0.01946147345006466
Epoch 95/100


100%|██████████| 64/64 [01:09<00:00,  1.08s/it, loss=0.241]
100%|██████████| 13/13 [00:11<00:00,  1.09it/s]


Epoch 95 average loss: 0.26755376579239964 with testing loss of 0.0251374039798975
Epoch 96/100


100%|██████████| 64/64 [01:07<00:00,  1.05s/it, loss=0.255]
100%|██████████| 13/13 [00:12<00:00,  1.08it/s]


Epoch 96 average loss: 0.2669812529347837 with testing loss of 0.022934790700674057
Epoch 97/100


100%|██████████| 64/64 [01:06<00:00,  1.04s/it, loss=0.272]
100%|██████████| 13/13 [00:12<00:00,  1.08it/s]


Epoch 97 average loss: 0.2664372813887894 with testing loss of 0.023356301710009575
Epoch 98/100


100%|██████████| 64/64 [01:06<00:00,  1.03s/it, loss=0.25] 
100%|██████████| 13/13 [00:12<00:00,  1.06it/s]


Epoch 98 average loss: 0.265776411164552 with testing loss of 0.017748035490512848
Epoch 99/100


100%|██████████| 64/64 [01:06<00:00,  1.04s/it, loss=0.263]
100%|██████████| 13/13 [00:11<00:00,  1.10it/s]


Epoch 99 average loss: 0.2652573613449931 with testing loss of 0.020988788455724716
Epoch 100/100


100%|██████████| 64/64 [01:06<00:00,  1.03s/it, loss=0.29] 
100%|██████████| 13/13 [00:12<00:00,  1.04it/s]


Epoch 100 average loss: 0.26479940698482096 with testing loss of 0.016836050897836685
Training complete!


In [9]:
state.keys()

dict_keys(['loss', 'epochs', 'test_loss'])

In [13]:
plt.scatter()

[tensor(0.0524),
 tensor(0.0512),
 tensor(0.0510),
 tensor(0.0500),
 tensor(0.0509),
 tensor(0.0472),
 tensor(0.0469),
 tensor(0.0469),
 tensor(0.0446),
 tensor(0.0439),
 tensor(0.0438),
 tensor(0.0455),
 tensor(0.0415),
 tensor(0.0419),
 tensor(0.0395),
 tensor(0.0433),
 tensor(0.0422),
 tensor(0.0380),
 tensor(0.0374),
 tensor(0.0418),
 tensor(0.0372),
 tensor(0.0384),
 tensor(0.0354),
 tensor(0.0375),
 tensor(0.0356),
 tensor(0.0370),
 tensor(0.0337),
 tensor(0.0324),
 tensor(0.0329),
 tensor(0.0333),
 tensor(0.0308),
 tensor(0.0295),
 tensor(0.0303),
 tensor(0.0294),
 tensor(0.0329),
 tensor(0.0256),
 tensor(0.0320),
 tensor(0.0325),
 tensor(0.0275),
 tensor(0.0286),
 tensor(0.0241),
 tensor(0.0277),
 tensor(0.0279),
 tensor(0.0264),
 tensor(0.0270),
 tensor(0.0286),
 tensor(0.0240),
 tensor(0.0254),
 tensor(0.0254),
 tensor(0.0254),
 tensor(0.0253),
 tensor(0.0269),
 tensor(0.0229),
 tensor(0.0281),
 tensor(0.0247),
 tensor(0.0249),
 tensor(0.0247),
 tensor(0.0228),
 tensor(0.0225