# Train

## Import Libraries

In [1]:
# torch
import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
from torch.utils.data import Dataset, DataLoader

from tqdm import trange
import numpy as np
import os

# import custom modules
from UNet import UNet, Nested_UNet
from utils import load_zipped_pickle, save_zipped_pickle

# config device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cpu_device = torch.device("cpu")

## Load Data

In [2]:
# load data from npy files (original data)
class TrainImageDataset(Dataset):
    """Training pairs read from the low- and high-resolution ``.npy`` files.

    The low-resolution and high-resolution arrays are concatenated along
    axis 0, so indices first run over the low-resolution entries and then
    over the high-resolution ones.
    """

    def __init__(self) -> None:
        super().__init__()
        # Files are read in the order listed: low-res x/y, then high-res x/y.
        low_x, low_y, high_x, high_y = (
            np.load(path)
            for path in (
                "./Data/low_resolution_x.npy",
                "./Data/low_resolution_y.npy",
                "./Data/high_resolution_x.npy",
                "./Data/high_resolution_y.npy",
            )
        )

        # Stack both sources along the sample axis.
        self.train_x = np.concatenate([low_x, high_x], axis=0)
        self.train_y = np.concatenate([low_y, high_y], axis=0)

    def __len__(self):
        """Return the number of entries along axis 0 of ``train_x``."""
        return len(self.train_x)

    def __getitem__(self, idx):
        """Return ``(x, y)`` as float32 tensors with a leading dim of size 1."""
        # Cast to float32 first; the arrays' stored dtype is not known here.
        image = torch.from_numpy(self.train_x[idx]).to(torch.float32)
        target = torch.from_numpy(self.train_y[idx]).to(torch.float32)
        # Prepend the channel dimension to both the input and the target.
        return image.unsqueeze(0), target.unsqueeze(0)

# Build the dataset, then wrap it in a shuffling loader that yields batches of 64.
train_dataset = TrainImageDataset()
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True,
)

In [2]:
# load data from npy files (original data) (high-resolution images only)
class TrainImageDataset(Dataset):
    """Training pairs read from the ``high_resolution_*_256.npy`` files."""

    def __init__(self) -> None:
        super().__init__()
        self.train_x = np.load("./Data/high_resolution_x_256.npy")
        self.train_y = np.load("./Data/high_resolution_y_256.npy")

    @staticmethod
    def _with_channel_dim(array):
        """Convert a NumPy array to a float32 tensor and prepend a size-1 dim."""
        as_float = torch.from_numpy(array).float()
        return torch.unsqueeze(as_float, dim=0)

    def __len__(self):
        """Return the number of entries along axis 0 of ``train_x``."""
        return len(self.train_x)

    def __getitem__(self, idx):
        """Return the ``(x, y)`` pair at ``idx``, each with a channel dim added."""
        x_i = self._with_channel_dim(self.train_x[idx])
        y_i = self._with_channel_dim(self.train_y[idx])
        return x_i, y_i

# Instantiate the dataset and serve it in shuffled batches of 32.
train_dataset = TrainImageDataset()
train_dataloader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)

In [None]:
# load data from npy files (original data) (high-resolution images only)
class TrainImageDataset(Dataset):
    """Dataset backed by the ``high_resolution_x_256`` / ``high_resolution_y_256`` arrays."""

    def __init__(self) -> None:
        super().__init__()
        self.train_x = np.load("./Data/high_resolution_x_256.npy")
        self.train_y = np.load("./Data/high_resolution_y_256.npy")

    def __len__(self):
        """Return the size of ``train_x`` along its first axis."""
        sample_count = self.train_x.shape[0]
        return sample_count

    def __getitem__(self, idx):
        """Return a 2-tuple ``(x, y)`` of float32 tensors with a leading size-1 dim."""
        # Input first, target second; both get the same cast and channel dim.
        return tuple(
            torch.from_numpy(source[idx]).float().unsqueeze(0)
            for source in (self.train_x, self.train_y)
        )

# Create the dataset and a shuffling DataLoader with batch size 32.
train_dataset = TrainImageDataset()
train_dataloader = DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=32,
)

In [2]:
# load data from npy files (multiple channels data)
class TrainImageDataset(Dataset):
    """Training pairs read from the ``*_multi_channels.npy`` files.

    Low- and high-resolution arrays are concatenated along axis 0. Inputs are
    returned as stored (no channel dim is added); targets get a leading
    dimension of size 1.
    """

    def __init__(self) -> None:
        super().__init__()
        # Files are read in the order listed: low-res x/y, then high-res x/y.
        low_x, low_y, high_x, high_y = (
            np.load(path)
            for path in (
                "./Data/low_resolution_x_multi_channels.npy",
                "./Data/low_resolution_y_multi_channels.npy",
                "./Data/high_resolution_x_multi_channels.npy",
                "./Data/high_resolution_y_multi_channels.npy",
            )
        )

        self.train_x = np.concatenate([low_x, high_x], axis=0)
        self.train_y = np.concatenate([low_y, high_y], axis=0)

    def __len__(self):
        """Return the number of entries along axis 0 of ``train_x``."""
        return len(self.train_x)

    def __getitem__(self, idx):
        """Return ``(x, y)`` as float32 tensors; only ``y`` gains a channel dim."""
        # The input is only cast to float32; no channel dim is added to it.
        features = torch.from_numpy(self.train_x[idx]).to(torch.float32)
        # The target is cast and then given a leading dimension of size 1.
        target = torch.from_numpy(self.train_y[idx]).to(torch.float32)
        return features, torch.unsqueeze(target, 0)

# Multi-channel dataset served in shuffled batches of 128.
train_dataset = TrainImageDataset()
train_dataloader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)

In [2]:
# load data from npy files (preprocessed data)
class TrainImageDatasetProcessed(Dataset):
    """Dataset of 256x256 windows cut from the preprocessed ``train_x`` / ``train_y`` arrays.

    Every image in each array is cut into overlapping windows (stride 32) by
    ``sample``; inputs and targets are windowed independently with the same
    settings.
    """

    def __init__(self) -> None:
        super().__init__()
        self.train_x = np.load("./Data/train_x.npy")
        self.train_y = np.load("./Data/train_y.npy")
        # Flatten the per-image window lists into one list per array.
        self.train_x_samples = [
            window
            for image in self.train_x
            for window in self.sample(image, (256, 256), stride=32)
        ]
        self.train_y_samples = [
            window
            for image in self.train_y
            for window in self.sample(image, (256, 256), stride=32)
        ]

    def __len__(self):
        """Return the number of input windows."""
        return len(self.train_x_samples)

    def __getitem__(self, idx):
        """Return the ``idx``-th ``(x, y)`` window pair as float32 tensors with a channel dim."""
        x_window = torch.from_numpy(self.train_x_samples[idx]).float()
        y_window = torch.from_numpy(self.train_y_samples[idx]).float()
        # Prepend the channel dimension to both windows.
        return x_window.unsqueeze(0), y_window.unsqueeze(0)

    def sample(self, img: np.ndarray, window_size: tuple, stride: int):
        """Cut ``img`` into windows of ``window_size`` taken every ``stride`` pixels.

        Windows are produced in row-major order (top offset outer, left offset
        inner). Only windows that fit fully inside the first two axes of
        ``img`` are produced, so an image smaller than the window gives an
        empty list. The windows are plain slices of ``img``.
        """
        win_h = window_size[0]
        win_w = window_size[1]
        return [
            img[top:top + win_h, left:left + win_w]
            for top in range(0, img.shape[0] - win_h + 1, stride)
            for left in range(0, img.shape[1] - win_w + 1, stride)
        ]

# Windowed (preprocessed) dataset served in shuffled batches of 32.
train_dataset = TrainImageDatasetProcessed()
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
)

## Train Model

### UNet Training

In [3]:
# Train the UNet (bilinear) model (estimated time: 1h15min) (batch size can be set to 128)
# NOTE(review): in_channels=7 presumably pairs with the multi-channel dataset
# cell -- confirm that the active train_dataloader yields 7-channel inputs.
model = UNet(in_channels=7, out_channels=1, bilinear=True).to(device)
# BCEWithLogitsLoss applies the sigmoid itself, so the model output is fed in
# as raw logits; a sigmoid is only needed later, at prediction time.
criterion = nn.BCEWithLogitsLoss()
optimizer = Adam(model.parameters(), lr=1e-3)

# Create the checkpoint directory before training so that a missing folder
# cannot make torch.save fail after the whole run has already finished.
os.makedirs("./Model", exist_ok=True)

epochs = 100
loss_history = []  # one loss value per processed batch, across all epochs
# Select training mode explicitly, matching the other training cells; this
# matters if the model was switched to eval mode before this cell is re-run.
model.train()
with trange(epochs, desc="Model Training") as t:
    postfix = {}
    for epoch in t:
        for batch_i, (batch_x, batch_y) in enumerate(train_dataloader):
            optimizer.zero_grad()
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            pred_y = model(batch_x)
            loss = criterion(pred_y, batch_y)
            loss.backward()
            optimizer.step()

            # .item() copies the scalar loss to the host as a Python float.
            check_loss = loss.item()
            loss_history.append(check_loss)
            # Despite the label, this is (batch index + 1) * batch size, i.e.
            # an upper bound on the samples consumed so far in this epoch.
            postfix["batch id"] = (batch_i + 1) * train_dataloader.batch_size
            postfix["loss"] = check_loss
            t.set_postfix(postfix)

# Save the trained weights
torch.save(model.state_dict(), "./Model/UNet_multi_channels.pt")

Model Training: 100%|██████████| 100/100 [2:18:59<00:00, 83.39s/it, batch id=39680, loss=1.91e-5]  


In [3]:
# Train the UNet (ConvTranspose) model (estimated time: 1h15min) (batch size can be set to 128)
model = UNet(in_channels=1, out_channels=1, bilinear=False).to(device)
# BCEWithLogitsLoss applies the sigmoid itself, so the model output is fed in
# as raw logits; a sigmoid is only needed later, at prediction time.
criterion = nn.BCEWithLogitsLoss()
optimizer = Adam(model.parameters(), lr=1e-3)

# Create the checkpoint directory before training so that a missing folder
# cannot make torch.save fail after the whole run has already finished.
os.makedirs("./Model", exist_ok=True)

epochs = 105
loss_history = []  # one loss value per processed batch, across all epochs
model.train()
with trange(epochs, desc="Model Training") as t:
    postfix = {}
    for epoch in t:
        for batch_i, (batch_x, batch_y) in enumerate(train_dataloader):
            optimizer.zero_grad()
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            pred_y = model(batch_x)
            loss = criterion(pred_y, batch_y)
            loss.backward()
            optimizer.step()

            # .item() copies the scalar loss to the host as a Python float.
            check_loss = loss.item()
            loss_history.append(check_loss)
            # Despite the label, this is (batch index + 1) * batch size, i.e.
            # an upper bound on the samples consumed so far in this epoch.
            postfix["batch id"] = (batch_i + 1) * train_dataloader.batch_size
            postfix["loss"] = check_loss
            t.set_postfix(postfix)

# Save the trained weights
torch.save(model.state_dict(), "./Model/UNet_4_4_transpose_conv.pt")

Model Training: 100%|██████████| 105/105 [1:16:19<00:00, 43.62s/it, batch id=23296, loss=1.15e-5] 


In [3]:
# Train the UNet++ model (estimated time: 1h15min) (batch size can be set to 64)
model = Nested_UNet(in_channels=1, out_channels=1).to(device)
# BCEWithLogitsLoss applies the sigmoid itself, so the model output is fed in
# as raw logits; a sigmoid is only needed later, at prediction time.
criterion = nn.BCEWithLogitsLoss()
optimizer = Adam(model.parameters(), lr=1e-3)

# Create the checkpoint directory before training so that a missing folder
# cannot make torch.save fail after the whole run has already finished.
os.makedirs("./Model", exist_ok=True)

epochs = 100
loss_history = []  # one loss value per processed batch, across all epochs
model.train()
with trange(epochs, desc="Model Training") as t:
    postfix = {}
    for epoch in t:
        for batch_i, (batch_x, batch_y) in enumerate(train_dataloader):
            optimizer.zero_grad()
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            pred_y = model(batch_x)
            loss = criterion(pred_y, batch_y)
            loss.backward()
            optimizer.step()

            # .item() copies the scalar loss to the host as a Python float.
            check_loss = loss.item()
            loss_history.append(check_loss)
            # Despite the label, this is (batch index + 1) * batch size, i.e.
            # an upper bound on the samples consumed so far in this epoch.
            postfix["batch id"] = (batch_i + 1) * train_dataloader.batch_size
            postfix["loss"] = check_loss
            t.set_postfix(postfix)

# Save the trained weights
torch.save(model.state_dict(), "./Model/NestedUNet_4_4.pt")

Model Training: 100%|██████████| 100/100 [3:41:00<00:00, 132.60s/it, batch id=23296, loss=9.48e-5]  
