In [1]:
import torch
import torch.nn as nn
import torch.functional as F 
from tqdm import tqdm
from model.mscred import MSCRED
from utils.data import load_data
import matplotlib.pyplot as plt
import numpy as np
import os

In [2]:
def train(dataLoader, model, optimizer, epochs, device):
    model = model.to(device)
    print("------training on {}-------".format(device))
    for epoch in range(epochs):
        train_l_sum,n = 0.0, 0
        for x in tqdm(dataLoader):
            x = x.to(device)
            x = x.squeeze()
            #print(type(x))

            '''
            torch.autograd.set_detect_anomaly(True)
            model_output = model(x)
            l = torch.mean((model_output-x[-1].unsqueeze(0))**2)
            '''
            
            l = torch.mean((model(x)-x[-1].unsqueeze(0))**2)
            train_l_sum += l
            optimizer.zero_grad()
            '''
            with torch.autograd.detect_anomaly():
                l.backward()
            '''
            l.backward()
            optimizer.step()
            n += 1
            #print("[Epoch %d/%d][Batch %d/%d] [loss: %f]" % (epoch+1, epochs, n, len(dataLoader), l.item()))
            
        print("[Epoch %d/%d] [loss: %f]" % (epoch+1, epochs, train_l_sum/n))

def test(dataLoader, model):
    print("------Testing-------")
    index = 800
    loss_list = []
    reconstructed_data_path = "./utils/data/matrix_data/reconstructed_data/"
    
    if not os.path.exists(reconstructed_data_path):
        os.makedirs(reconstructed_data_path)
    
    with torch.no_grad():
        for x in dataLoader:
            x = x.to(device)
            x = x.squeeze()
            reconstructed_matrix = model(x) 
            path_temp = os.path.join(reconstructed_data_path, 'reconstructed_data_' + str(index) + ".npy")
            np.save(path_temp, reconstructed_matrix.cpu().detach().numpy())
            # l = criterion(reconstructed_matrix, x[-1].unsqueeze(0)).mean()
            # loss_list.append(l)
            # print("[test_index %d] [loss: %f]" % (index, l.item()))
            index += 1

In [3]:
if __name__ == '__main__':
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("device is", device)
    dataLoader = load_data()
    mscred = MSCRED(3, 256)

    # 训练阶段
    # mscred.load_state_dict(torch.load("./checkpoints/model1.pth"))
    optimizer = torch.optim.Adam(mscred.parameters(), lr = 0.0002)
    train(dataLoader["train"], mscred, optimizer, 10, device)

device is cuda
------training on cuda-------


  return self._call_impl(*args, **kwargs)
100%|█████████████████████████████████████████████████████████████████████████████████| 789/789 [00:39<00:00, 19.86it/s]


[Epoch 1/10] [loss: 0.001739]


100%|█████████████████████████████████████████████████████████████████████████████████| 789/789 [00:39<00:00, 20.05it/s]


[Epoch 2/10] [loss: 0.000159]


100%|█████████████████████████████████████████████████████████████████████████████████| 789/789 [00:41<00:00, 19.21it/s]


[Epoch 3/10] [loss: 0.000087]


100%|█████████████████████████████████████████████████████████████████████████████████| 789/789 [00:43<00:00, 18.22it/s]


[Epoch 4/10] [loss: 0.000066]


100%|█████████████████████████████████████████████████████████████████████████████████| 789/789 [00:38<00:00, 20.48it/s]


[Epoch 5/10] [loss: 0.000046]


100%|█████████████████████████████████████████████████████████████████████████████████| 789/789 [00:37<00:00, 20.90it/s]


[Epoch 6/10] [loss: 0.000041]


100%|█████████████████████████████████████████████████████████████████████████████████| 789/789 [00:39<00:00, 20.14it/s]


[Epoch 7/10] [loss: 0.000034]


100%|█████████████████████████████████████████████████████████████████████████████████| 789/789 [00:37<00:00, 21.09it/s]


[Epoch 8/10] [loss: 0.000032]


100%|█████████████████████████████████████████████████████████████████████████████████| 789/789 [00:37<00:00, 21.30it/s]


[Epoch 9/10] [loss: 0.000025]


100%|█████████████████████████████████████████████████████████████████████████████████| 789/789 [00:38<00:00, 20.48it/s]


[Epoch 10/10] [loss: 0.000026]


In [4]:
    print("保存模型中....")
    #torch.save(mscred.state_dict(), "./checkpoints/model2.pth")

保存模型中....


In [5]:
    # # 测试阶段
    mscred.load_state_dict(torch.load("./checkpoints/model2.pth"))
    mscred.to(device)
    test(dataLoader["test"], mscred)

------Testing-------


In [6]:
    sum = 0
    #criterion = torch.nn.MSELoss()
    criterion = torch.sub
    
    print("------Testing-------")
    index = 800
    loss_list = []
    reconstructed_data_path = "./utils/data/matrix_data/reconstructed_data/"
    
    if not os.path.exists(reconstructed_data_path):
        os.makedirs(reconstructed_data_path)
    
    with torch.no_grad():
        for x in dataLoader['test']:
            x = x.to(device)
            x = x.squeeze()
            reconstructed_matrix = mscred(x) 
            #path_temp = os.path.join(reconstructed_data_path, 'reconstructed_data_' + str(index) + ".npy")
            #np.save(path_temp, reconstructed_matrix.cpu().detach().numpy())
            l = criterion(reconstructed_matrix, x[-1].unsqueeze(0)).mean()
            sum += l
            loss_list.append(l)
            #print("[test_index %d] [loss: %f]" % (index, l.item()))
            index += 1

------Testing-------


In [7]:
mean = sum / len(loss_list)

In [8]:
mean

tensor(-0.0008, device='cuda:0')

In [9]:
torch.stack(loss_list[:10])[:]

tensor([-0.0097, -0.0101, -0.0073, -0.0056, -0.0042, -0.0012, -0.0028, -0.0005,
         0.0026,  0.0027], device='cuda:0')

In [10]:
std = torch.std(torch.stack(loss_list[:]))

In [11]:
std

tensor(0.0101, device='cuda:0')

In [12]:
mean + std * 2

tensor(0.0194, device='cuda:0')