In [1]:
from pathlib import Path
import json
import os
import random
import sys
import time
import numpy as np
import matplotlib.pyplot as plt

from torch import nn, optim
from torchvision import models, datasets, transforms
import torch
import torchvision
import torch.nn.functional as F
from voxel_data_generator import SSL_Dataset

In [2]:
def _weights_init(m):  # 權重初始化
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv3d):
        nn.init.kaiming_normal_(m.weight)

In [3]:
class ConvNet_module(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv3d(1, 64, kernel_size=3, stride=1, bias=False, padding=1)
        self.conv2 = nn.Conv3d(64, 128, kernel_size=3, stride=1, bias=False, padding=1)
        self.conv3 = nn.Conv3d(128, 256, kernel_size=3, stride=1, bias=False, padding=1)
        self.conv4 = nn.Conv3d(256, 512, kernel_size=3, stride=1, bias=False, padding=1)
        self.conv5 = nn.Conv3d(512, 1024, kernel_size=3, stride=1, bias=False, padding=1)
        self.bn1 = nn.BatchNorm3d(64)
        self.bn2 = nn.BatchNorm3d(128)
        self.bn3 = nn.BatchNorm3d(256)
        self.bn4 = nn.BatchNorm3d(512)
        self.bn5 = nn.BatchNorm3d(1024)
        self.pool1 = nn.MaxPool3d(2)
        self.pool2 = nn.MaxPool3d(2)
        self.pool3 = nn.MaxPool3d(2)
        self.pool4 = nn.MaxPool3d(2)
        self.pool5 = nn.AvgPool3d(4)
        self.apply(_weights_init)
        self.fc = nn.Sequential(
                                  nn.Linear(1024, 128),
                                  nn.ReLU(),
                                  nn.Dropout(0.25),
                                  nn.Linear(128, 1),
                                )

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool1(x)
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool2(x)
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.pool3(x)
        x = F.relu(self.bn4(self.conv4(x)))
        x = self.pool4(x)
        x = F.relu(self.bn5(self.conv5(x)))
        x = self.pool5(x)
        x = torch.squeeze(x)
        x = self.fc(x)
        return x

In [4]:
def main():
    model = ConvNet_module().cuda()
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=3e-3)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 200)
    train_dataset = SSL_Dataset(train=True)
    val_dataset = SSL_Dataset(train=False)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=64, num_workers=2,
        pin_memory=True, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=64, num_workers=2)
    
    start_time = time.time()
    best_val = 10
    for epoch in range(200):
        model.train()
        for step, (images, target) in enumerate(train_loader, start=epoch * len(train_loader)):
            output = model(images.cuda(non_blocking=True))
            loss = criterion(output, target.cuda(non_blocking=True).float())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if step % 5 == 0:
                pg = optimizer.param_groups
                stats = dict(epoch=epoch, step=step, loss=loss.item(),
                             time=int(time.time() - start_time))
                print(stats)
        model.eval()
        for inputs, labels in val_loader:
            y_pred = model(inputs.cuda())
            loss = criterion(y_pred, labels.cuda(non_blocking=True).float())
            print('val_loss', loss.cpu().detach().numpy())
            if loss.item() < best_val:
                best_val = loss.item()
                torch.save(model.state_dict(),'bumpSup_0922.pth')
                
    print(best_val)
    state = torch.load('bumpSup_0922.pth', map_location='cpu')
    model.load_state_dict(state, strict=False)
    val_set(val_loader, model)
    
    

def val_set(dataset, model):
    test_loader = dataset
    deviation = 0.0
    mean, sigma = 0, 1
    y, l = np.array([]), np.array([])
    s = 0
    for inputs, labels in test_loader:
            y_pred = model(inputs.cuda())
            y_pred, labels = ((y_pred.cpu().detach().numpy().reshape(-1, 1)*sigma)+mean), ((labels.cpu().detach().numpy().reshape(-1, 1)*sigma)+mean)
            y = np.append(y, y_pred)
            l = np.append(l, labels.reshape(-1, 1))
            s = plt.scatter(labels, y_pred, c='red',alpha=0.5)
    plt.plot([-2, 2.2], [-2, 2.2], c='black', ls='--')
#     plt.text(0.0125/1.5*1000, 0.0155/1.5*1000, 'RMSE='+str(round(np.sqrt(mse), 4))+'(mΩ)',fontsize=12)
    plt.ylabel('predition', fontsize=18)
    plt.xlabel('ground truth', fontsize=18)
    plt.show()
    plt.close()

In [None]:
if __name__ == '__main__':
    main()

{'epoch': 0, 'step': 0, 'loss': 1.0525420904159546, 'time': 2}
val_loss 0.69510007
{'epoch': 1, 'step': 5, 'loss': 0.5866701602935791, 'time': 6}
val_loss 0.77658874
{'epoch': 2, 'step': 10, 'loss': 0.3845071792602539, 'time': 9}
val_loss 1.1991413
{'epoch': 3, 'step': 15, 'loss': 0.4072030186653137, 'time': 12}
val_loss 1.5943378
{'epoch': 4, 'step': 20, 'loss': 0.3344174921512604, 'time': 15}
val_loss 2.521074
{'epoch': 5, 'step': 25, 'loss': 0.24699482321739197, 'time': 18}
val_loss 3.5379293
{'epoch': 6, 'step': 30, 'loss': 0.2328643649816513, 'time': 21}
val_loss 4.9382687
{'epoch': 7, 'step': 35, 'loss': 0.242182657122612, 'time': 24}
val_loss 3.5618892
{'epoch': 8, 'step': 40, 'loss': 0.29148513078689575, 'time': 27}
val_loss 3.433142
{'epoch': 9, 'step': 45, 'loss': 0.24071571230888367, 'time': 30}
val_loss 4.975531
{'epoch': 10, 'step': 50, 'loss': 0.197666734457016, 'time': 33}
val_loss 4.1634192
{'epoch': 11, 'step': 55, 'loss': 0.14531727135181427, 'time': 36}
val_loss 1.30

{'epoch': 94, 'step': 470, 'loss': 0.04826023429632187, 'time': 292}
val_loss 0.4884953
{'epoch': 95, 'step': 475, 'loss': 0.02017037943005562, 'time': 295}
val_loss 1.2929963
{'epoch': 96, 'step': 480, 'loss': 0.05160146206617355, 'time': 298}
val_loss 0.4951358
{'epoch': 97, 'step': 485, 'loss': 0.05379893630743027, 'time': 301}
val_loss 0.96505296
{'epoch': 98, 'step': 490, 'loss': 0.017524488270282745, 'time': 304}
val_loss 0.57834107
