# Train

## Import Libraries

In [5]:
# torch
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torch.utils.data import Dataset, DataLoader

from tqdm import trange
import numpy as np

# import custom modules
from UNet import UNet

# Choose the compute device: the current CUDA device when PyTorch can see a
# GPU, otherwise the CPU. `cpu_device` is kept as an explicit handle for
# bringing tensors back to host memory.
cpu_device = torch.device("cpu")
device = torch.device("cuda") if torch.cuda.is_available() else cpu_device

## Load Data

In [13]:
# Load the paired training images from .npy files.
class TrainImageDataset(Dataset):
    """Paired single-channel (input, target) images read from ``.npy`` files.

    The low-resolution and high-resolution subsets are concatenated along the
    sample axis, low-resolution first. Indices ``0 .. len(low) - 1`` therefore
    address low-resolution pairs, and the remaining indices address
    high-resolution pairs. Each stored array is assumed to be ``(N, H, W)``
    (TODO confirm); ``__getitem__`` adds a leading channel axis of size 1.

    Args:
        device: Device every returned sample tensor is moved to.
        data_dir: Directory containing ``low_resolution_x.npy``,
            ``low_resolution_y.npy``, ``high_resolution_x.npy`` and
            ``high_resolution_y.npy``. Defaults to ``"./Data"``.

    Raises:
        ValueError: If inputs and targets hold a different number of samples.
    """

    def __init__(self, device: torch.device, data_dir: str = "./Data") -> None:
        super().__init__()
        self.device = device

        def load(name: str) -> np.ndarray:
            # Forward slashes are accepted on every platform; with the default
            # data_dir this resolves to "./Data/<name>.npy".
            return np.load(f"{data_dir}/{name}.npy")

        self.train_x = np.concatenate(
            (load("low_resolution_x"), load("high_resolution_x")), axis=0
        )
        self.train_y = np.concatenate(
            (load("low_resolution_y"), load("high_resolution_y")), axis=0
        )
        # A length mismatch would otherwise silently pair inputs with the
        # wrong targets (or fail only at the last indices).
        if self.train_x.shape[0] != self.train_y.shape[0]:
            raise ValueError(
                "input/target sample counts differ: "
                f"{self.train_x.shape[0]} vs {self.train_y.shape[0]}"
            )

    def __len__(self) -> int:
        """Return the total number of (input, target) pairs."""
        return self.train_x.shape[0]

    def __getitem__(self, idx: int):
        """Return pair ``idx`` as two float32 tensors with a leading channel axis.

        Both tensors are placed on ``self.device``.
        """
        x_i = self._to_tensor(self.train_x[idx])
        y_i = self._to_tensor(self.train_y[idx])
        return x_i, y_i

    def _to_tensor(self, array: np.ndarray) -> torch.Tensor:
        # Cast to float32 (PyTorch's default parameter dtype), because .npy
        # files may hold float64/uint8 data. Cast before the device copy so
        # the transfer moves float32 bytes. Add the channel dim (size 1).
        # NOTE(review): moving to `device` here means the DataLoader should
        # stay single-process when `device` is CUDA; with num_workers > 0,
        # move batches in the training loop instead.
        return torch.from_numpy(array).float().unsqueeze(0).to(self.device)

# Training set (each sample is moved to `device` on access) and a loader that
# reshuffles every epoch and yields mini-batches of 64 pairs.
train_dataset = TrainImageDataset(device)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

## Train Model

In [14]:
# Train the model.
model = UNet(in_channels=1, out_channels=1, bilinear=True).to(device)
optimizer = Adam(model.parameters(), lr=1e-3)

# With a single output channel, F.cross_entropy normalizes over a class axis of
# size one: the softmax is identically 1, so the loss is a constant 0 and no
# gradient reaches the weights. Binary cross-entropy on logits is the
# one-channel counterpart.
# NOTE(review): assumes targets lie in [0, 1] (e.g. masks). If they are
# continuous intensities, use F.mse_loss instead. Confirm against the data.
loss_fn = F.binary_cross_entropy_with_logits

epochs = 100
model.train()  # explicit: enable training-mode behavior (dropout/batch-norm)
with trange(epochs, desc="Model Training") as t:
    postfix = {}
    for _ in t:
        for batch_x, batch_y in train_dataloader:
            optimizer.zero_grad()
            pred = model(batch_x)
            # BCE needs floating-point targets of the same dtype as the logits.
            loss = loss_fn(pred, batch_y.to(pred.dtype))
            loss.backward()
            optimizer.step()

            # .item() already copies the scalar to host memory.
            postfix["loss"] = loss.item()
            t.set_postfix(postfix)

torch.Size([64, 1, 112, 112])
torch.Size([64, 1, 112, 112])
