In [5]:
import cv2
import os
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
import torch
from tqdm import tqdm
import torch.optim as optim
import torch.nn as nn
from model import HPCFNet
from data.processing import make_dataset

In [6]:
# Prepare the data and the model used by every later cell.
# NOTE(review): make_dataset() presumably (re)writes "dataset.pt" — confirm in data/processing.
make_dataset()
# NOTE(review): torch.load unpickles the file; only load files produced locally.
dataset = torch.load("dataset.pt")
model = HPCFNet(32)  # NOTE(review): meaning of the 32 argument is defined in model.py — confirm
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hold out 10% of the samples for validation. The lengths are derived from the
# dataset size (90/10 for 100 samples) so the split keeps working when the
# dataset grows or shrinks, and the generator is seeded so a kernel restart
# reproduces the same split.
n_valid = len(dataset) // 10
n_train = len(dataset) - n_valid
train_dataset, valid_dataset = random_split(
    dataset,
    [n_train, n_valid],
    generator=torch.Generator().manual_seed(0),
)

train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
# The mean validation loss does not depend on batch order, so no shuffling is needed.
valid_loader = DataLoader(valid_dataset, batch_size=2, shuffle=False)


100%|██████████| 100/100 [00:00<00:00, 141.56it/s]


In [7]:
def train(device, model, train_loader, valid_loader, epochs = 10, lr = 0.001):
    """Train ``model`` with Adam and cross-entropy loss, validating and checkpointing every epoch.

    Args:
        device: ``torch.device`` the model and every batch are moved to.
        model: ``nn.Module`` to optimise.
        train_loader: Sized iterable of ``(inputs, labels)`` batches used for optimisation.
        valid_loader: Sized iterable of ``(inputs, labels)`` batches used for evaluation only.
        epochs: Number of passes over ``train_loader``.
        lr: Learning rate passed to Adam.

    Returns:
        ``(train_losses, valid_losses)``: two lists holding the mean batch loss
        of every epoch. An empty loader contributes ``0.0`` for its epoch.

    Side effects:
        Writes ``state_dict/<epoch>.pt`` (zero-based epoch index, relative to
        the current working directory) after every epoch.
    """
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr = lr)
    criterion = nn.CrossEntropyLoss()

    # Create the checkpoint directory once. exist_ok tolerates re-runs without
    # hiding real failures (e.g. permission errors), and building the path with
    # os.path keeps it valid on non-Windows systems too.
    checkpoint_dir = "state_dict"
    os.makedirs(checkpoint_dir, exist_ok=True)

    train_losses = []
    valid_losses = []
    for epoch in range(epochs):
        # ---- optimisation pass ----
        model.train()
        epoch_loss = 0.0
        for inputs, labels in tqdm(train_loader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        # max(..., 1) avoids a ZeroDivisionError when the loader is empty.
        epoch_loss /= max(len(train_loader), 1)
        train_losses.append(epoch_loss)
        print(f'--- Epoch {epoch+1}/{epochs}: Train loss: {epoch_loss:.4f}')

        # ---- evaluation pass (no gradients needed) ----
        model.eval()
        epoch_loss = 0.0
        with torch.no_grad():
            for inputs, labels in valid_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                epoch_loss += loss.item()
        epoch_loss /= max(len(valid_loader), 1)
        valid_losses.append(epoch_loss)
        print(f'--- Epoch {epoch+1}/{epochs}: valid loss: {epoch_loss:.4f}')

        torch.save(model.state_dict(), os.path.join(checkpoint_dir, "{}.pt".format(epoch)))
    return train_losses, valid_losses



In [8]:
# Keep the per-epoch loss histories so later cells can inspect or plot them
# instead of relying on the notebook's Out[...] cache.
train_losses, valid_losses = train(device, model, train_loader, valid_loader, epochs=100)

  7%|▋         | 3/45 [00:01<00:16,  2.49it/s]


KeyboardInterrupt: 

In [10]:
def view_output(i):
    """Run the model on dataset sample ``i`` and save the predicted class map as ``<i>.jpg``.

    Reads the notebook-level ``dataset``, ``model`` and ``device``. The
    prediction is the arg-max over dimension 1 of the model output, scaled by
    255 and saturated to the 8-bit range before being written with OpenCV.

    Args:
        i: Index into ``dataset``; also used as the output file name.
    """
    img = dataset[i][0].to(device).unsqueeze(0)  # add a leading batch dimension

    # Make sure the weights live on the same device as the input, and run in
    # inference mode: no dropout/batch-norm updates and no autograd graph.
    model.to(device)
    model.eval()
    with torch.no_grad():
        output = model(img)

    predicted = output.argmax(dim=1)
    # cv2.imwrite is only guaranteed to handle 8-bit images, so convert
    # explicitly; clamp keeps class ids > 1 at 255 instead of wrapping around.
    mask = (predicted * 255).clamp(0, 255).to(torch.uint8)
    # NOTE(review): assuming the model output is (1, C, H, W), this yields a
    # (W, H, 1) array, i.e. the two spatial axes are swapped — confirm this
    # matches the axis order the dataset stores its images in.
    a = mask.cpu().numpy().transpose(2, 1, 0)
    # Name the file after the sample index (it was previously named after the
    # device, so every call overwrote the same "cuda.jpg"/"cpu.jpg").
    cv2.imwrite("{}.jpg".format(i), a)
