In [1]:
import cv2
import os
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataset import random_split
import torch
import torch.nn.functional as F
from tqdm import tqdm
from data.dataset import TSUNAMIDataset
import torch.optim as optim
import torch.nn as nn
from model import HPCFNet_tiny, HPCFNet_small
from data.processing import make_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
make_dataset()

In [2]:
dataset = torch.load("dataset.pt")
model = HPCFNet_small()

train_dataset, valid_dataset = random_split(dataset, [90, 10])

train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=2, shuffle=True)


In [3]:
def train(model, train_loader, valid_loader, epochs = 10, lr = 0.001):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr = lr)
    criterion = nn.CrossEntropyLoss()
    train_losses = []
    valid_losses = []
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        for inputs, labels in tqdm(train_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        epoch_loss /= len(train_loader)
        train_losses.append(epoch_loss)
        print(f'--- Epoch {epoch+1}/{epochs}: Train loss: {epoch_loss:.4f}')

        model.eval()
        epoch_loss = 0.0
        for inputs, labels in valid_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            with torch.no_grad():
                outputs = model(inputs)
                loss = criterion(outputs, labels)
            epoch_loss += loss.item()
        epoch_loss /= len(valid_loader)
        valid_losses.append(epoch_loss)
        print(f'--- Epoch {epoch+1}/{epochs}: valid loss: {epoch_loss:.4f}')
        try:
            os.makedirs(".\\state_dict")
        except:
            pass
        torch.save(model.state_dict(), ".\\state_dict\\{}.pt".format(epoch))
    return train_losses, valid_losses



In [4]:
train(model, train_loader, valid_loader, 100)

100%|██████████| 45/45 [00:15<00:00,  2.94it/s]


--- Epoch 1/100: Train loss: 0.5033
--- Epoch 1/100: valid loss: 0.4848


100%|██████████| 45/45 [00:12<00:00,  3.59it/s]


--- Epoch 2/100: Train loss: 0.4621
--- Epoch 2/100: valid loss: 0.4584


100%|██████████| 45/45 [00:12<00:00,  3.57it/s]


--- Epoch 3/100: Train loss: 0.4417
--- Epoch 3/100: valid loss: 0.4628


100%|██████████| 45/45 [00:12<00:00,  3.57it/s]


--- Epoch 4/100: Train loss: 0.4290
--- Epoch 4/100: valid loss: 0.4344


100%|██████████| 45/45 [00:12<00:00,  3.57it/s]


--- Epoch 5/100: Train loss: 0.4060
--- Epoch 5/100: valid loss: 0.4373


100%|██████████| 45/45 [00:12<00:00,  3.57it/s]


--- Epoch 6/100: Train loss: 0.3915
--- Epoch 6/100: valid loss: 0.4279


100%|██████████| 45/45 [00:12<00:00,  3.57it/s]


--- Epoch 7/100: Train loss: 0.3789
--- Epoch 7/100: valid loss: 0.4028


100%|██████████| 45/45 [00:12<00:00,  3.57it/s]


--- Epoch 8/100: Train loss: 0.3762
--- Epoch 8/100: valid loss: 0.4335


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 9/100: Train loss: 0.3604
--- Epoch 9/100: valid loss: 0.3820


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 10/100: Train loss: 0.3431
--- Epoch 10/100: valid loss: 0.3993


100%|██████████| 45/45 [00:12<00:00,  3.57it/s]


--- Epoch 11/100: Train loss: 0.3371
--- Epoch 11/100: valid loss: 0.3955


100%|██████████| 45/45 [00:12<00:00,  3.57it/s]


--- Epoch 12/100: Train loss: 0.3183
--- Epoch 12/100: valid loss: 0.4664


100%|██████████| 45/45 [00:12<00:00,  3.55it/s]


--- Epoch 13/100: Train loss: 0.3147
--- Epoch 13/100: valid loss: 0.3698


100%|██████████| 45/45 [00:12<00:00,  3.57it/s]


--- Epoch 14/100: Train loss: 0.3045
--- Epoch 14/100: valid loss: 0.3735


100%|██████████| 45/45 [00:12<00:00,  3.57it/s]


--- Epoch 15/100: Train loss: 0.2986
--- Epoch 15/100: valid loss: 0.3765


100%|██████████| 45/45 [00:12<00:00,  3.55it/s]


--- Epoch 16/100: Train loss: 0.2844
--- Epoch 16/100: valid loss: 0.3667


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 17/100: Train loss: 0.2854
--- Epoch 17/100: valid loss: 0.3472


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 18/100: Train loss: 0.2784
--- Epoch 18/100: valid loss: 0.3256


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 19/100: Train loss: 0.2695
--- Epoch 19/100: valid loss: 0.3425


100%|██████████| 45/45 [00:12<00:00,  3.57it/s]


--- Epoch 20/100: Train loss: 0.2642
--- Epoch 20/100: valid loss: 0.3563


100%|██████████| 45/45 [00:12<00:00,  3.57it/s]


--- Epoch 21/100: Train loss: 0.2539
--- Epoch 21/100: valid loss: 0.3208


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 22/100: Train loss: 0.2502
--- Epoch 22/100: valid loss: 0.3576


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 23/100: Train loss: 0.2447
--- Epoch 23/100: valid loss: 0.3171


100%|██████████| 45/45 [00:12<00:00,  3.57it/s]


--- Epoch 24/100: Train loss: 0.2383
--- Epoch 24/100: valid loss: 0.3105


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 25/100: Train loss: 0.2374
--- Epoch 25/100: valid loss: 0.3611


100%|██████████| 45/45 [00:12<00:00,  3.57it/s]


--- Epoch 26/100: Train loss: 0.2341
--- Epoch 26/100: valid loss: 0.3660


100%|██████████| 45/45 [00:12<00:00,  3.55it/s]


--- Epoch 27/100: Train loss: 0.2266
--- Epoch 27/100: valid loss: 0.3625


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 28/100: Train loss: 0.2280
--- Epoch 28/100: valid loss: 0.3204


100%|██████████| 45/45 [00:12<00:00,  3.57it/s]


--- Epoch 29/100: Train loss: 0.2175
--- Epoch 29/100: valid loss: 0.3086


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 30/100: Train loss: 0.2197
--- Epoch 30/100: valid loss: 0.3453


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 31/100: Train loss: 0.2108
--- Epoch 31/100: valid loss: 0.2921


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 32/100: Train loss: 0.2132
--- Epoch 32/100: valid loss: 0.3475


100%|██████████| 45/45 [00:12<00:00,  3.46it/s]


--- Epoch 33/100: Train loss: 0.2192
--- Epoch 33/100: valid loss: 0.3645


100%|██████████| 45/45 [00:13<00:00,  3.45it/s]


--- Epoch 34/100: Train loss: 0.1985
--- Epoch 34/100: valid loss: 0.2852


100%|██████████| 45/45 [00:13<00:00,  3.45it/s]


--- Epoch 35/100: Train loss: 0.1990
--- Epoch 35/100: valid loss: 0.2887


100%|██████████| 45/45 [00:13<00:00,  3.46it/s]


--- Epoch 36/100: Train loss: 0.1861
--- Epoch 36/100: valid loss: 0.2841


100%|██████████| 45/45 [00:12<00:00,  3.46it/s]


--- Epoch 37/100: Train loss: 0.1887
--- Epoch 37/100: valid loss: 0.3221


100%|██████████| 45/45 [00:12<00:00,  3.47it/s]


--- Epoch 38/100: Train loss: 0.1832
--- Epoch 38/100: valid loss: 0.3292


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 39/100: Train loss: 0.1834
--- Epoch 39/100: valid loss: 0.3189


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 40/100: Train loss: 0.1797
--- Epoch 40/100: valid loss: 0.3293


100%|██████████| 45/45 [00:12<00:00,  3.57it/s]


--- Epoch 41/100: Train loss: 0.1782
--- Epoch 41/100: valid loss: 0.3015


100%|██████████| 45/45 [00:12<00:00,  3.57it/s]


--- Epoch 42/100: Train loss: 0.1796
--- Epoch 42/100: valid loss: 0.2767


100%|██████████| 45/45 [00:12<00:00,  3.57it/s]


--- Epoch 43/100: Train loss: 0.1753
--- Epoch 43/100: valid loss: 0.2910


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 44/100: Train loss: 0.1702
--- Epoch 44/100: valid loss: 0.2914


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 45/100: Train loss: 0.1623
--- Epoch 45/100: valid loss: 0.2922


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 46/100: Train loss: 0.1579
--- Epoch 46/100: valid loss: 0.3907


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 47/100: Train loss: 0.1685
--- Epoch 47/100: valid loss: 0.3146


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 48/100: Train loss: 0.1599
--- Epoch 48/100: valid loss: 0.3307


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 49/100: Train loss: 0.1598
--- Epoch 49/100: valid loss: 0.3361


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 50/100: Train loss: 0.1595
--- Epoch 50/100: valid loss: 0.3274


100%|██████████| 45/45 [00:12<00:00,  3.57it/s]


--- Epoch 51/100: Train loss: 0.1515
--- Epoch 51/100: valid loss: 0.2726


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 52/100: Train loss: 0.1459
--- Epoch 52/100: valid loss: 0.3201


100%|██████████| 45/45 [00:12<00:00,  3.57it/s]


--- Epoch 53/100: Train loss: 0.1437
--- Epoch 53/100: valid loss: 0.2889


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 54/100: Train loss: 0.1398
--- Epoch 54/100: valid loss: 0.2753


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 55/100: Train loss: 0.1416
--- Epoch 55/100: valid loss: 0.3592


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 56/100: Train loss: 0.1422
--- Epoch 56/100: valid loss: 0.2807


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 57/100: Train loss: 0.1343
--- Epoch 57/100: valid loss: 0.3138


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 58/100: Train loss: 0.1411
--- Epoch 58/100: valid loss: 0.3232


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 59/100: Train loss: 0.1405
--- Epoch 59/100: valid loss: 0.2812


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 60/100: Train loss: 0.1333
--- Epoch 60/100: valid loss: 0.3123


100%|██████████| 45/45 [00:12<00:00,  3.57it/s]


--- Epoch 61/100: Train loss: 0.1280
--- Epoch 61/100: valid loss: 0.3248


100%|██████████| 45/45 [00:12<00:00,  3.57it/s]


--- Epoch 62/100: Train loss: 0.1294
--- Epoch 62/100: valid loss: 0.3110


100%|██████████| 45/45 [00:12<00:00,  3.57it/s]


--- Epoch 63/100: Train loss: 0.1303
--- Epoch 63/100: valid loss: 0.3411


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 64/100: Train loss: 0.1232
--- Epoch 64/100: valid loss: 0.2966


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 65/100: Train loss: 0.1203
--- Epoch 65/100: valid loss: 0.3458


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 66/100: Train loss: 0.1216
--- Epoch 66/100: valid loss: 0.3255


100%|██████████| 45/45 [00:12<00:00,  3.57it/s]


--- Epoch 67/100: Train loss: 0.1255
--- Epoch 67/100: valid loss: 0.2868


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 68/100: Train loss: 0.1197
--- Epoch 68/100: valid loss: 0.2861


100%|██████████| 45/45 [00:12<00:00,  3.57it/s]


--- Epoch 69/100: Train loss: 0.1165
--- Epoch 69/100: valid loss: 0.3111


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 70/100: Train loss: 0.1255
--- Epoch 70/100: valid loss: 0.3690


100%|██████████| 45/45 [00:12<00:00,  3.55it/s]


--- Epoch 71/100: Train loss: 0.1203
--- Epoch 71/100: valid loss: 0.3616


100%|██████████| 45/45 [00:12<00:00,  3.55it/s]


--- Epoch 72/100: Train loss: 0.1189
--- Epoch 72/100: valid loss: 0.3061


100%|██████████| 45/45 [00:12<00:00,  3.55it/s]


--- Epoch 73/100: Train loss: 0.1095
--- Epoch 73/100: valid loss: 0.3072


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 74/100: Train loss: 0.1070
--- Epoch 74/100: valid loss: 0.3000


100%|██████████| 45/45 [00:12<00:00,  3.52it/s]


--- Epoch 75/100: Train loss: 0.1049
--- Epoch 75/100: valid loss: 0.3168


100%|██████████| 45/45 [00:13<00:00,  3.42it/s]


--- Epoch 76/100: Train loss: 0.1044
--- Epoch 76/100: valid loss: 0.3173


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 77/100: Train loss: 0.1030
--- Epoch 77/100: valid loss: 0.3467


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 78/100: Train loss: 0.1035
--- Epoch 78/100: valid loss: 0.3171


100%|██████████| 45/45 [00:12<00:00,  3.57it/s]


--- Epoch 79/100: Train loss: 0.1033
--- Epoch 79/100: valid loss: 0.3280


100%|██████████| 45/45 [00:12<00:00,  3.57it/s]


--- Epoch 80/100: Train loss: 0.0980
--- Epoch 80/100: valid loss: 0.3173


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 81/100: Train loss: 0.0965
--- Epoch 81/100: valid loss: 0.3085


100%|██████████| 45/45 [00:12<00:00,  3.55it/s]


--- Epoch 82/100: Train loss: 0.0972
--- Epoch 82/100: valid loss: 0.3345


100%|██████████| 45/45 [00:12<00:00,  3.53it/s]


--- Epoch 83/100: Train loss: 0.0961
--- Epoch 83/100: valid loss: 0.3076


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 84/100: Train loss: 0.0950
--- Epoch 84/100: valid loss: 0.3151


100%|██████████| 45/45 [00:12<00:00,  3.55it/s]


--- Epoch 85/100: Train loss: 0.0916
--- Epoch 85/100: valid loss: 0.3308


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 86/100: Train loss: 0.0899
--- Epoch 86/100: valid loss: 0.3189


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 87/100: Train loss: 0.0879
--- Epoch 87/100: valid loss: 0.3400


100%|██████████| 45/45 [00:12<00:00,  3.57it/s]


--- Epoch 88/100: Train loss: 0.0883
--- Epoch 88/100: valid loss: 0.3469


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 89/100: Train loss: 0.0906
--- Epoch 89/100: valid loss: 0.3451


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 90/100: Train loss: 0.0906
--- Epoch 90/100: valid loss: 0.3814


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 91/100: Train loss: 0.0932
--- Epoch 91/100: valid loss: 0.3629


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 92/100: Train loss: 0.0879
--- Epoch 92/100: valid loss: 0.3008


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 93/100: Train loss: 0.0833
--- Epoch 93/100: valid loss: 0.3277


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 94/100: Train loss: 0.0830
--- Epoch 94/100: valid loss: 0.3765


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 95/100: Train loss: 0.0802
--- Epoch 95/100: valid loss: 0.3364


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 96/100: Train loss: 0.0811
--- Epoch 96/100: valid loss: 0.3473


100%|██████████| 45/45 [00:12<00:00,  3.56it/s]


--- Epoch 97/100: Train loss: 0.0809
--- Epoch 97/100: valid loss: 0.3678


100%|██████████| 45/45 [00:12<00:00,  3.55it/s]


--- Epoch 98/100: Train loss: 0.0776
--- Epoch 98/100: valid loss: 0.3675


100%|██████████| 45/45 [00:12<00:00,  3.54it/s]


--- Epoch 99/100: Train loss: 0.0801
--- Epoch 99/100: valid loss: 0.4346


100%|██████████| 45/45 [00:12<00:00,  3.50it/s]


--- Epoch 100/100: Train loss: 0.0867
--- Epoch 100/100: valid loss: 0.5099


([0.5032904313670264,
  0.4621398660871718,
  0.4416995512114631,
  0.42897841665479874,
  0.40599270794126724,
  0.3914594007862939,
  0.37888173527187774,
  0.3761853198210398,
  0.36039407054583233,
  0.34307861394352385,
  0.3370922048886617,
  0.31833916041586136,
  0.3146543006102244,
  0.3045283675193787,
  0.2986345324251387,
  0.2844205445713467,
  0.2853945622841517,
  0.27843518786960175,
  0.26947075956397587,
  0.26422218018107946,
  0.25392633080482485,
  0.250242132279608,
  0.24466769728395674,
  0.23832402659787072,
  0.23735981384913127,
  0.2340998606549369,
  0.22662045823203192,
  0.2279953615532981,
  0.21746777958340116,
  0.21974430978298187,
  0.21076729132069483,
  0.2131688680913713,
  0.21919376684559716,
  0.19849019315507677,
  0.19901009036435022,
  0.186129958430926,
  0.18874241941505007,
  0.18320203721523284,
  0.18340885837872822,
  0.1796521094110277,
  0.17817179759343466,
  0.1795615126689275,
  0.1752854550878207,
  0.17024667263031007,
  0.16229