# 作业一：Covid-19新冠预测

给出美国某州过去5天内的调查结果，然后预测第5天的新检测阳性病例的百分比。

In [1]:
# 导入必要的包
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split

import math
import numpy as np
import pandas as pd
from tqdm import tqdm
import os
from torch.utils.tensorboard import SummaryWriter

In [2]:
# 下载数据
if not os.path.exists('covid.train.csv'):
    os.system("wget 'https://drive.google.com/uc?id=1kLSW_-cW2Huj7bh84YTdimGBOJaODiOS' -O covid.train.csv")
if not os.path.exists('covid.test.csv'):
    os.system("wget 'https://drive.google.com/uc?id=1iiI5qROrAhZn-o4FPqsE97bMzDEFvIdg' -O covid.test.csv")

### 示例代码提供的功能函数

In [3]:
def same_seed(seed): 
    '''Fixes random number generator seeds for reproducibility.'''
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def train_valid_split(data_set, valid_ratio, seed):
    '''Split provided training data into training set and validation set'''
    valid_set_size = int(valid_ratio * len(data_set)) 
    train_set_size = len(data_set) - valid_set_size
    train_set, valid_set = random_split(data_set, [train_set_size, valid_set_size], generator=torch.Generator().manual_seed(seed))
    return np.array(train_set), np.array(valid_set)

def predict(test_loader, model, device):
    model.eval() # Set your model to evaluation mode.
    preds = []
    for x in tqdm(test_loader):
        x = x.to(device)                        
        with torch.no_grad():                   
            pred = model(x)                     
            preds.append(pred.detach().cpu())   
    preds = torch.cat(preds, dim=0).numpy()  
    return preds

### 定义数据集类

In [4]:
class COVID19Dataset(Dataset):
    '''
    x: Features.
    y: Targets, if none, do prediction.
    '''
    def __init__(self, x, y=None):
        if y is None:
            self.y = y
        else:
            self.y = torch.FloatTensor(y)
        self.x = torch.FloatTensor(x)

    def __getitem__(self, idx):
        if self.y is None:
            return self.x[idx]
        else:
            return self.x[idx], self.y[idx]

    def __len__(self):
        return len(self.x)

### 神经网络定义

In [5]:
class CovidNet(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_features=input_dim, out_features=16),
            nn.ReLU(),
            nn.Linear(16, 8),
            nn.ReLU(),
            nn.Linear(8, 1)
        )

    def forward(self, x):
        x = self.layers(x)
        x = x.squeeze(1) # (B, 1) -> (B)
        return x

### 特征挑选
在训练集所给出的新冠数据各类特征中，挑选出我们所认为适合此次回归任务的

In [6]:
def select_feat(train_data, valid_data, test_data, select_all=True):
    '''Selects useful features to perform regression'''
    y_train, y_valid = train_data[:,-1], valid_data[:,-1]
    raw_x_train, raw_x_valid, raw_x_test = train_data[:,:-1], valid_data[:,:-1], test_data

    if select_all:
        feat_idx = list(range(raw_x_train.shape[1]))
    else:
        feat_idx = [0,1,2,3,4] # TODO: Select suitable feature columns.
        
    return raw_x_train[:,feat_idx], raw_x_valid[:,feat_idx], raw_x_test[:,feat_idx], y_train, y_valid

### 训练循环

In [7]:
def trainer(train_loader, valid_loader, model, config, device):

    criterion = nn.MSELoss(reduction='mean') # Define your loss function, do not modify this.

    # Define your optimization algorithm. 
    # TODO: Please check https://pytorch.org/docs/stable/optim.html to get more available algorithms.
    # TODO: L2 regularization (optimizer(weight decay...) or implement by your self).
    optimizer = torch.optim.SGD(model.parameters(), lr=config['learning_rate'], momentum=0.9) 

    writer = SummaryWriter() # Writer of tensoboard.

    if not os.path.isdir('./models'):
        os.mkdir('./models') # Create directory of saving models.

    n_epochs, best_loss, step, early_stop_count = config['n_epochs'], math.inf, 0, 0

    for epoch in range(n_epochs):
        model.train() # Set your model to train mode.
        loss_record = []

        # tqdm is a package to visualize your training progress.
        train_pbar = tqdm(train_loader, position=0, leave=True)

        for x, y in train_pbar:
            optimizer.zero_grad()               # Set gradient to zero.
            x, y = x.to(device), y.to(device)   # Move your data to device. 
            pred = model(x)             
            loss = criterion(pred, y)
            loss.backward()                     # Compute gradient(backpropagation).
            optimizer.step()                    # Update parameters.
            step += 1
            loss_record.append(loss.detach().item())
            
            # Display current epoch number and loss on tqdm progress bar.
            train_pbar.set_description(f'Epoch [{epoch+1}/{n_epochs}]')
            train_pbar.set_postfix({'loss': loss.detach().item()})

        mean_train_loss = sum(loss_record)/len(loss_record)
        writer.add_scalar('Loss/train', mean_train_loss, step)

        model.eval() # Set your model to evaluation mode.
        loss_record = []
        for x, y in valid_loader:
            x, y = x.to(device), y.to(device)
            with torch.no_grad():
                pred = model(x)
                loss = criterion(pred, y)

            loss_record.append(loss.item())
            
        mean_valid_loss = sum(loss_record)/len(loss_record)
        print(f'Epoch [{epoch+1}/{n_epochs}]: Train loss: {mean_train_loss:.4f}, Valid loss: {mean_valid_loss:.4f}')
        writer.add_scalar('Loss/valid', mean_valid_loss, step)

        if mean_valid_loss < best_loss:
            best_loss = mean_valid_loss
            torch.save(model.state_dict(), config['save_path']) # Save your best model
            print('Saving model with loss {:.3f}...'.format(best_loss))
            early_stop_count = 0
        else: 
            early_stop_count += 1

        if early_stop_count >= config['early_stop']:
            print('\nModel is not improving, so we halt the training session.')
            return

### 定义配置

In [8]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
config = {
    'seed': 666,          # Your seed number, you can pick your lucky number. :)
    'select_all': True,   # Whether to use all features.
    'valid_ratio': 0.2,   # validation_size = train_size * valid_ratio
    'n_epochs': 200,     # Number of epochs.            
    'batch_size': 256, 
    'learning_rate': 1e-5,              
    'early_stop': 400,    # If model has not improved for this many consecutive epochs, stop training.     
    'save_path': './models/model.ckpt'  # Your model will be saved here.
}

### 加载数据集

In [9]:
# Set seed for reproducibility
same_seed(config['seed'])


# train_data size: 2699 x 118 (id + 37 states + 16 features x 5 days) 
# test_data size: 1078 x 117 (without last day's positive rate)
train_data, test_data = pd.read_csv('./covid.train.csv').values, pd.read_csv('./covid.test.csv').values
train_data, valid_data = train_valid_split(train_data, config['valid_ratio'], config['seed'])

# Print out the data size.
print(f"""train_data size: {train_data.shape} 
valid_data size: {valid_data.shape} 
test_data size: {test_data.shape}""")

# Select features
x_train, x_valid, x_test, y_train, y_valid = select_feat(train_data, valid_data, test_data, config['select_all'])

# Print out the number of features.
print(f'number of features: {x_train.shape[1]}')

train_dataset, valid_dataset, test_dataset = COVID19Dataset(x_train, y_train), \
                                            COVID19Dataset(x_valid, y_valid), \
                                            COVID19Dataset(x_test)

# Pytorch data loader loads pytorch dataset into batches.
train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, pin_memory=True)
valid_loader = DataLoader(valid_dataset, batch_size=config['batch_size'], shuffle=True, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=config['batch_size'], shuffle=False, pin_memory=True)

train_data size: (2160, 118) 
valid_data size: (539, 118) 
test_data size: (1078, 117)
number of features: 117


### 开始训练

In [10]:
model = CovidNet(input_dim=x_train.shape[1]).to(device)
trainer(train_loader, valid_loader, model, config, device)

Epoch [1/200]: 100%|██████████| 9/9 [00:00<00:00, 20.15it/s, loss=57.8]


Epoch [1/200]: Train loss: 90.8912, Valid loss: 42.9727
Saving model with loss 42.973...


Epoch [2/200]: 100%|██████████| 9/9 [00:00<00:00, 196.54it/s, loss=54]


Epoch [2/200]: Train loss: 51.5021, Valid loss: 48.1348


Epoch [3/200]: 100%|██████████| 9/9 [00:00<00:00, 184.08it/s, loss=54.6]


Epoch [3/200]: Train loss: 46.6594, Valid loss: 41.5541
Saving model with loss 41.554...


Epoch [4/200]: 100%|██████████| 9/9 [00:00<00:00, 183.91it/s, loss=50.1]


Epoch [4/200]: Train loss: 44.6555, Valid loss: 43.0886


Epoch [5/200]: 100%|██████████| 9/9 [00:00<00:00, 182.60it/s, loss=45.9]


Epoch [5/200]: Train loss: 42.1789, Valid loss: 41.3225
Saving model with loss 41.322...


Epoch [6/200]: 100%|██████████| 9/9 [00:00<00:00, 186.87it/s, loss=48.8]


Epoch [6/200]: Train loss: 40.6075, Valid loss: 33.7650
Saving model with loss 33.765...


Epoch [7/200]: 100%|██████████| 9/9 [00:00<00:00, 189.22it/s, loss=37.4]


Epoch [7/200]: Train loss: 37.9492, Valid loss: 35.5147


Epoch [8/200]: 100%|██████████| 9/9 [00:00<00:00, 182.30it/s, loss=43.7]


Epoch [8/200]: Train loss: 36.5279, Valid loss: 41.6572


Epoch [9/200]: 100%|██████████| 9/9 [00:00<00:00, 184.12it/s, loss=31.4]


Epoch [9/200]: Train loss: 34.6344, Valid loss: 35.0068


Epoch [10/200]: 100%|██████████| 9/9 [00:00<00:00, 184.28it/s, loss=25]


Epoch [10/200]: Train loss: 33.2061, Valid loss: 31.5049
Saving model with loss 31.505...


Epoch [11/200]: 100%|██████████| 9/9 [00:00<00:00, 187.16it/s, loss=37.3]


Epoch [11/200]: Train loss: 33.3295, Valid loss: 39.3667


Epoch [12/200]: 100%|██████████| 9/9 [00:00<00:00, 179.44it/s, loss=32.8]


Epoch [12/200]: Train loss: 33.7731, Valid loss: 31.0151
Saving model with loss 31.015...


Epoch [13/200]: 100%|██████████| 9/9 [00:00<00:00, 185.37it/s, loss=27.6]


Epoch [13/200]: Train loss: 31.2540, Valid loss: 30.3379
Saving model with loss 30.338...


Epoch [14/200]: 100%|██████████| 9/9 [00:00<00:00, 186.41it/s, loss=29.1]


Epoch [14/200]: Train loss: 30.4717, Valid loss: 31.6926


Epoch [15/200]: 100%|██████████| 9/9 [00:00<00:00, 181.92it/s, loss=32.7]


Epoch [15/200]: Train loss: 29.4868, Valid loss: 31.6856


Epoch [16/200]: 100%|██████████| 9/9 [00:00<00:00, 186.40it/s, loss=29.3]


Epoch [16/200]: Train loss: 28.5162, Valid loss: 27.1612
Saving model with loss 27.161...


Epoch [17/200]: 100%|██████████| 9/9 [00:00<00:00, 185.86it/s, loss=35.8]


Epoch [17/200]: Train loss: 28.9836, Valid loss: 36.3498


Epoch [18/200]: 100%|██████████| 9/9 [00:00<00:00, 179.75it/s, loss=30.3]


Epoch [18/200]: Train loss: 30.7669, Valid loss: 33.0369


Epoch [19/200]: 100%|██████████| 9/9 [00:00<00:00, 169.29it/s, loss=31]


Epoch [19/200]: Train loss: 34.0429, Valid loss: 33.5774


Epoch [20/200]: 100%|██████████| 9/9 [00:00<00:00, 181.49it/s, loss=27]


Epoch [20/200]: Train loss: 28.6513, Valid loss: 22.1485
Saving model with loss 22.149...


Epoch [21/200]: 100%|██████████| 9/9 [00:00<00:00, 75.90it/s, loss=23.7]


Epoch [21/200]: Train loss: 23.2636, Valid loss: 24.0064


Epoch [22/200]: 100%|██████████| 9/9 [00:00<00:00, 187.57it/s, loss=15.3]


Epoch [22/200]: Train loss: 21.3253, Valid loss: 23.5690


Epoch [23/200]: 100%|██████████| 9/9 [00:00<00:00, 163.65it/s, loss=23.5]


Epoch [23/200]: Train loss: 20.5684, Valid loss: 15.9413
Saving model with loss 15.941...


Epoch [24/200]: 100%|██████████| 9/9 [00:00<00:00, 185.03it/s, loss=15.1]


Epoch [24/200]: Train loss: 18.2925, Valid loss: 18.1259


Epoch [25/200]: 100%|██████████| 9/9 [00:00<00:00, 179.79it/s, loss=12.4]


Epoch [25/200]: Train loss: 14.3623, Valid loss: 13.4190
Saving model with loss 13.419...


Epoch [26/200]: 100%|██████████| 9/9 [00:00<00:00, 188.05it/s, loss=23.2]


Epoch [26/200]: Train loss: 17.6693, Valid loss: 40.4270


Epoch [27/200]: 100%|██████████| 9/9 [00:00<00:00, 178.75it/s, loss=48]


Epoch [27/200]: Train loss: 43.2811, Valid loss: 76.6484


Epoch [28/200]: 100%|██████████| 9/9 [00:00<00:00, 179.40it/s, loss=22.9]


Epoch [28/200]: Train loss: 37.5781, Valid loss: 32.7658


Epoch [29/200]: 100%|██████████| 9/9 [00:00<00:00, 166.72it/s, loss=24]


Epoch [29/200]: Train loss: 22.6961, Valid loss: 15.3990


Epoch [30/200]: 100%|██████████| 9/9 [00:00<00:00, 181.22it/s, loss=20.9]


Epoch [30/200]: Train loss: 17.6019, Valid loss: 15.5376


Epoch [31/200]: 100%|██████████| 9/9 [00:00<00:00, 180.72it/s, loss=11]


Epoch [31/200]: Train loss: 13.9720, Valid loss: 16.2149


Epoch [32/200]: 100%|██████████| 9/9 [00:00<00:00, 187.29it/s, loss=15.1]


Epoch [32/200]: Train loss: 12.1482, Valid loss: 11.6349
Saving model with loss 11.635...


Epoch [33/200]: 100%|██████████| 9/9 [00:00<00:00, 179.83it/s, loss=8.92]


Epoch [33/200]: Train loss: 9.5982, Valid loss: 9.3556
Saving model with loss 9.356...


Epoch [34/200]: 100%|██████████| 9/9 [00:00<00:00, 177.71it/s, loss=8.08]


Epoch [34/200]: Train loss: 9.3716, Valid loss: 8.3948
Saving model with loss 8.395...


Epoch [35/200]: 100%|██████████| 9/9 [00:00<00:00, 177.75it/s, loss=7.95]


Epoch [35/200]: Train loss: 7.3325, Valid loss: 9.3536


Epoch [36/200]: 100%|██████████| 9/9 [00:00<00:00, 177.04it/s, loss=7.28]


Epoch [36/200]: Train loss: 12.1056, Valid loss: 13.6680


Epoch [37/200]: 100%|██████████| 9/9 [00:00<00:00, 187.26it/s, loss=15.3]


Epoch [37/200]: Train loss: 20.0213, Valid loss: 11.9654


Epoch [38/200]: 100%|██████████| 9/9 [00:00<00:00, 181.36it/s, loss=6.77]


Epoch [38/200]: Train loss: 9.1326, Valid loss: 7.9322
Saving model with loss 7.932...


Epoch [39/200]: 100%|██████████| 9/9 [00:00<00:00, 180.62it/s, loss=9.44]


Epoch [39/200]: Train loss: 8.6231, Valid loss: 12.0466


Epoch [40/200]: 100%|██████████| 9/9 [00:00<00:00, 180.41it/s, loss=17.3]


Epoch [40/200]: Train loss: 12.5675, Valid loss: 6.4564
Saving model with loss 6.456...


Epoch [41/200]: 100%|██████████| 9/9 [00:00<00:00, 182.54it/s, loss=6.23]


Epoch [41/200]: Train loss: 10.0549, Valid loss: 11.4757


Epoch [42/200]: 100%|██████████| 9/9 [00:00<00:00, 163.40it/s, loss=8.03]


Epoch [42/200]: Train loss: 7.6487, Valid loss: 6.0605
Saving model with loss 6.060...


Epoch [43/200]: 100%|██████████| 9/9 [00:00<00:00, 186.35it/s, loss=9.37]


Epoch [43/200]: Train loss: 8.9507, Valid loss: 8.6006


Epoch [44/200]: 100%|██████████| 9/9 [00:00<00:00, 173.96it/s, loss=5.46]


Epoch [44/200]: Train loss: 7.5376, Valid loss: 6.3180


Epoch [45/200]: 100%|██████████| 9/9 [00:00<00:00, 186.97it/s, loss=4.13]


Epoch [45/200]: Train loss: 6.1561, Valid loss: 5.6846
Saving model with loss 5.685...


Epoch [46/200]: 100%|██████████| 9/9 [00:00<00:00, 187.02it/s, loss=8.13]


Epoch [46/200]: Train loss: 9.4218, Valid loss: 6.4995


Epoch [47/200]: 100%|██████████| 9/9 [00:00<00:00, 188.29it/s, loss=7.62]


Epoch [47/200]: Train loss: 8.5027, Valid loss: 6.9824


Epoch [48/200]: 100%|██████████| 9/9 [00:00<00:00, 191.74it/s, loss=5.93]


Epoch [48/200]: Train loss: 6.7294, Valid loss: 5.7230


Epoch [49/200]: 100%|██████████| 9/9 [00:00<00:00, 187.42it/s, loss=6]


Epoch [49/200]: Train loss: 8.7919, Valid loss: 7.0861


Epoch [50/200]: 100%|██████████| 9/9 [00:00<00:00, 185.03it/s, loss=6.77]


Epoch [50/200]: Train loss: 7.1994, Valid loss: 6.2361


Epoch [51/200]: 100%|██████████| 9/9 [00:00<00:00, 185.49it/s, loss=5.52]


Epoch [51/200]: Train loss: 6.3933, Valid loss: 5.9385


Epoch [52/200]: 100%|██████████| 9/9 [00:00<00:00, 182.03it/s, loss=5.76]


Epoch [52/200]: Train loss: 7.5064, Valid loss: 14.6239


Epoch [53/200]: 100%|██████████| 9/9 [00:00<00:00, 185.71it/s, loss=14.8]


Epoch [53/200]: Train loss: 13.1034, Valid loss: 14.6827


Epoch [54/200]: 100%|██████████| 9/9 [00:00<00:00, 184.12it/s, loss=7.1]


Epoch [54/200]: Train loss: 8.4254, Valid loss: 8.2893


Epoch [55/200]: 100%|██████████| 9/9 [00:00<00:00, 173.54it/s, loss=4.45]


Epoch [55/200]: Train loss: 6.9425, Valid loss: 5.5355
Saving model with loss 5.535...


Epoch [56/200]: 100%|██████████| 9/9 [00:00<00:00, 186.95it/s, loss=5.14]


Epoch [56/200]: Train loss: 5.9354, Valid loss: 5.4711
Saving model with loss 5.471...


Epoch [57/200]: 100%|██████████| 9/9 [00:00<00:00, 182.43it/s, loss=6.68]


Epoch [57/200]: Train loss: 6.1085, Valid loss: 5.5613


Epoch [58/200]: 100%|██████████| 9/9 [00:00<00:00, 185.27it/s, loss=5.94]


Epoch [58/200]: Train loss: 5.7682, Valid loss: 7.1364


Epoch [59/200]: 100%|██████████| 9/9 [00:00<00:00, 183.94it/s, loss=9.23]


Epoch [59/200]: Train loss: 6.6036, Valid loss: 7.1748


Epoch [60/200]: 100%|██████████| 9/9 [00:00<00:00, 181.53it/s, loss=7.26]


Epoch [60/200]: Train loss: 7.2034, Valid loss: 5.7793


Epoch [61/200]: 100%|██████████| 9/9 [00:00<00:00, 149.79it/s, loss=6]


Epoch [61/200]: Train loss: 5.8540, Valid loss: 5.2371
Saving model with loss 5.237...


Epoch [62/200]: 100%|██████████| 9/9 [00:00<00:00, 187.95it/s, loss=5.38]


Epoch [62/200]: Train loss: 6.1021, Valid loss: 6.2008


Epoch [63/200]: 100%|██████████| 9/9 [00:00<00:00, 177.75it/s, loss=5.3]


Epoch [63/200]: Train loss: 5.8541, Valid loss: 6.5433


Epoch [64/200]: 100%|██████████| 9/9 [00:00<00:00, 176.41it/s, loss=11.9]


Epoch [64/200]: Train loss: 6.7745, Valid loss: 11.5272


Epoch [65/200]: 100%|██████████| 9/9 [00:00<00:00, 192.44it/s, loss=11.7]


Epoch [65/200]: Train loss: 10.1007, Valid loss: 7.4855


Epoch [66/200]: 100%|██████████| 9/9 [00:00<00:00, 186.00it/s, loss=8.97]


Epoch [66/200]: Train loss: 8.2536, Valid loss: 7.6029


Epoch [67/200]: 100%|██████████| 9/9 [00:00<00:00, 195.29it/s, loss=5.83]


Epoch [67/200]: Train loss: 7.3849, Valid loss: 8.6600


Epoch [68/200]: 100%|██████████| 9/9 [00:00<00:00, 181.20it/s, loss=6.96]


Epoch [68/200]: Train loss: 7.8029, Valid loss: 5.7626


Epoch [69/200]: 100%|██████████| 9/9 [00:00<00:00, 194.11it/s, loss=5.65]


Epoch [69/200]: Train loss: 5.6917, Valid loss: 6.2034


Epoch [70/200]: 100%|██████████| 9/9 [00:00<00:00, 189.20it/s, loss=6]


Epoch [70/200]: Train loss: 6.3607, Valid loss: 5.2785


Epoch [71/200]: 100%|██████████| 9/9 [00:00<00:00, 193.52it/s, loss=7.02]


Epoch [71/200]: Train loss: 6.1064, Valid loss: 6.9013


Epoch [72/200]: 100%|██████████| 9/9 [00:00<00:00, 187.57it/s, loss=6.01]


Epoch [72/200]: Train loss: 6.3203, Valid loss: 7.3631


Epoch [73/200]: 100%|██████████| 9/9 [00:00<00:00, 191.87it/s, loss=7.33]


Epoch [73/200]: Train loss: 7.6353, Valid loss: 5.2641


Epoch [74/200]: 100%|██████████| 9/9 [00:00<00:00, 187.32it/s, loss=6.49]


Epoch [74/200]: Train loss: 5.9012, Valid loss: 5.6098


Epoch [75/200]: 100%|██████████| 9/9 [00:00<00:00, 190.87it/s, loss=6.9]


Epoch [75/200]: Train loss: 6.3378, Valid loss: 5.0566
Saving model with loss 5.057...


Epoch [76/200]: 100%|██████████| 9/9 [00:00<00:00, 201.17it/s, loss=5.78]


Epoch [76/200]: Train loss: 5.4463, Valid loss: 6.3005


Epoch [77/200]: 100%|██████████| 9/9 [00:00<00:00, 190.14it/s, loss=5.32]


Epoch [77/200]: Train loss: 5.4082, Valid loss: 4.2388
Saving model with loss 4.239...


Epoch [78/200]: 100%|██████████| 9/9 [00:00<00:00, 186.78it/s, loss=4.36]


Epoch [78/200]: Train loss: 5.2172, Valid loss: 5.7439


Epoch [79/200]: 100%|██████████| 9/9 [00:00<00:00, 182.33it/s, loss=6.1]


Epoch [79/200]: Train loss: 5.9793, Valid loss: 6.5533


Epoch [80/200]: 100%|██████████| 9/9 [00:00<00:00, 184.31it/s, loss=4.11]


Epoch [80/200]: Train loss: 5.7887, Valid loss: 8.1961


Epoch [81/200]: 100%|██████████| 9/9 [00:00<00:00, 191.13it/s, loss=5.96]


Epoch [81/200]: Train loss: 7.1399, Valid loss: 8.0553


Epoch [82/200]: 100%|██████████| 9/9 [00:00<00:00, 192.57it/s, loss=8.33]


Epoch [82/200]: Train loss: 6.4125, Valid loss: 5.4036


Epoch [83/200]: 100%|██████████| 9/9 [00:00<00:00, 184.19it/s, loss=4.89]


Epoch [83/200]: Train loss: 5.7174, Valid loss: 5.5278


Epoch [84/200]: 100%|██████████| 9/9 [00:00<00:00, 187.33it/s, loss=4.43]


Epoch [84/200]: Train loss: 5.5566, Valid loss: 4.9177


Epoch [85/200]: 100%|██████████| 9/9 [00:00<00:00, 187.93it/s, loss=7.57]


Epoch [85/200]: Train loss: 6.4720, Valid loss: 6.6862


Epoch [86/200]: 100%|██████████| 9/9 [00:00<00:00, 185.78it/s, loss=8.29]


Epoch [86/200]: Train loss: 6.8409, Valid loss: 5.3506


Epoch [87/200]: 100%|██████████| 9/9 [00:00<00:00, 188.40it/s, loss=4.02]


Epoch [87/200]: Train loss: 5.2510, Valid loss: 4.4606


Epoch [88/200]: 100%|██████████| 9/9 [00:00<00:00, 188.09it/s, loss=4.75]


Epoch [88/200]: Train loss: 5.1321, Valid loss: 4.9130


Epoch [89/200]: 100%|██████████| 9/9 [00:00<00:00, 192.43it/s, loss=5.57]


Epoch [89/200]: Train loss: 5.2885, Valid loss: 6.9941


Epoch [90/200]: 100%|██████████| 9/9 [00:00<00:00, 183.35it/s, loss=5.17]


Epoch [90/200]: Train loss: 5.6923, Valid loss: 5.2233


Epoch [91/200]: 100%|██████████| 9/9 [00:00<00:00, 186.31it/s, loss=8.53]


Epoch [91/200]: Train loss: 6.4670, Valid loss: 4.7237


Epoch [92/200]: 100%|██████████| 9/9 [00:00<00:00, 185.66it/s, loss=6.83]


Epoch [92/200]: Train loss: 7.5978, Valid loss: 5.0043


Epoch [93/200]: 100%|██████████| 9/9 [00:00<00:00, 185.53it/s, loss=5.74]


Epoch [93/200]: Train loss: 6.2281, Valid loss: 4.6846


Epoch [94/200]: 100%|██████████| 9/9 [00:00<00:00, 177.66it/s, loss=5.95]


Epoch [94/200]: Train loss: 5.3522, Valid loss: 4.6424


Epoch [95/200]: 100%|██████████| 9/9 [00:00<00:00, 184.99it/s, loss=4.49]


Epoch [95/200]: Train loss: 5.3068, Valid loss: 5.6017


Epoch [96/200]: 100%|██████████| 9/9 [00:00<00:00, 162.13it/s, loss=5.26]


Epoch [96/200]: Train loss: 5.2067, Valid loss: 5.4938


Epoch [97/200]: 100%|██████████| 9/9 [00:00<00:00, 177.15it/s, loss=4.35]


Epoch [97/200]: Train loss: 6.0269, Valid loss: 5.3459


Epoch [98/200]: 100%|██████████| 9/9 [00:00<00:00, 182.37it/s, loss=5.98]


Epoch [98/200]: Train loss: 5.4755, Valid loss: 4.9246


Epoch [99/200]: 100%|██████████| 9/9 [00:00<00:00, 168.50it/s, loss=5.3]


Epoch [99/200]: Train loss: 5.1135, Valid loss: 4.8599


Epoch [100/200]: 100%|██████████| 9/9 [00:00<00:00, 182.90it/s, loss=6.7]


Epoch [100/200]: Train loss: 6.2776, Valid loss: 5.0706


Epoch [101/200]: 100%|██████████| 9/9 [00:00<00:00, 180.58it/s, loss=4.64]


Epoch [101/200]: Train loss: 5.0417, Valid loss: 5.1223


Epoch [102/200]: 100%|██████████| 9/9 [00:00<00:00, 183.47it/s, loss=4.68]


Epoch [102/200]: Train loss: 4.9271, Valid loss: 5.6541


Epoch [103/200]: 100%|██████████| 9/9 [00:00<00:00, 161.26it/s, loss=3.93]


Epoch [103/200]: Train loss: 5.5525, Valid loss: 6.3992


Epoch [104/200]: 100%|██████████| 9/9 [00:00<00:00, 184.10it/s, loss=5.69]


Epoch [104/200]: Train loss: 5.4554, Valid loss: 5.2199


Epoch [105/200]: 100%|██████████| 9/9 [00:00<00:00, 178.79it/s, loss=5.51]


Epoch [105/200]: Train loss: 5.0407, Valid loss: 4.9486


Epoch [106/200]: 100%|██████████| 9/9 [00:00<00:00, 184.85it/s, loss=3.59]


Epoch [106/200]: Train loss: 5.0234, Valid loss: 5.0175


Epoch [107/200]: 100%|██████████| 9/9 [00:00<00:00, 181.80it/s, loss=5.35]


Epoch [107/200]: Train loss: 5.0871, Valid loss: 7.7658


Epoch [108/200]: 100%|██████████| 9/9 [00:00<00:00, 83.19it/s, loss=7.74]


Epoch [108/200]: Train loss: 8.3919, Valid loss: 9.2871


Epoch [109/200]: 100%|██████████| 9/9 [00:00<00:00, 158.32it/s, loss=5.56]


Epoch [109/200]: Train loss: 8.0710, Valid loss: 10.8972


Epoch [110/200]: 100%|██████████| 9/9 [00:00<00:00, 180.00it/s, loss=4.77]


Epoch [110/200]: Train loss: 8.3463, Valid loss: 14.0247


Epoch [111/200]: 100%|██████████| 9/9 [00:00<00:00, 182.24it/s, loss=11]


Epoch [111/200]: Train loss: 10.0503, Valid loss: 5.8415


Epoch [112/200]: 100%|██████████| 9/9 [00:00<00:00, 167.42it/s, loss=9.23]


Epoch [112/200]: Train loss: 9.7575, Valid loss: 5.8183


Epoch [113/200]: 100%|██████████| 9/9 [00:00<00:00, 219.79it/s, loss=5.31]


Epoch [113/200]: Train loss: 7.1886, Valid loss: 7.5293


Epoch [114/200]: 100%|██████████| 9/9 [00:00<00:00, 194.90it/s, loss=4.83]


Epoch [114/200]: Train loss: 6.0099, Valid loss: 5.4968


Epoch [115/200]: 100%|██████████| 9/9 [00:00<00:00, 180.53it/s, loss=5.09]


Epoch [115/200]: Train loss: 5.1895, Valid loss: 5.2665


Epoch [116/200]: 100%|██████████| 9/9 [00:00<00:00, 183.85it/s, loss=5.62]


Epoch [116/200]: Train loss: 5.2466, Valid loss: 5.5444


Epoch [117/200]: 100%|██████████| 9/9 [00:00<00:00, 171.07it/s, loss=4.92]


Epoch [117/200]: Train loss: 5.8036, Valid loss: 4.9768


Epoch [118/200]: 100%|██████████| 9/9 [00:00<00:00, 186.87it/s, loss=4.49]


Epoch [118/200]: Train loss: 5.3632, Valid loss: 4.4713


Epoch [119/200]: 100%|██████████| 9/9 [00:00<00:00, 176.15it/s, loss=3.67]


Epoch [119/200]: Train loss: 4.8919, Valid loss: 4.4215


Epoch [120/200]: 100%|██████████| 9/9 [00:00<00:00, 206.51it/s, loss=4.81]


Epoch [120/200]: Train loss: 4.8076, Valid loss: 4.3074


Epoch [121/200]: 100%|██████████| 9/9 [00:00<00:00, 184.23it/s, loss=6.25]


Epoch [121/200]: Train loss: 5.2056, Valid loss: 5.6029


Epoch [122/200]: 100%|██████████| 9/9 [00:00<00:00, 182.66it/s, loss=5.28]


Epoch [122/200]: Train loss: 5.3120, Valid loss: 4.2100
Saving model with loss 4.210...


Epoch [123/200]: 100%|██████████| 9/9 [00:00<00:00, 185.54it/s, loss=4.18]


Epoch [123/200]: Train loss: 4.6861, Valid loss: 6.2786


Epoch [124/200]: 100%|██████████| 9/9 [00:00<00:00, 164.02it/s, loss=4.77]


Epoch [124/200]: Train loss: 5.1310, Valid loss: 4.7267


Epoch [125/200]: 100%|██████████| 9/9 [00:00<00:00, 181.38it/s, loss=6.35]


Epoch [125/200]: Train loss: 5.0356, Valid loss: 5.1984


Epoch [126/200]: 100%|██████████| 9/9 [00:00<00:00, 183.81it/s, loss=5.21]


Epoch [126/200]: Train loss: 5.3707, Valid loss: 5.0233


Epoch [127/200]: 100%|██████████| 9/9 [00:00<00:00, 183.18it/s, loss=4.68]


Epoch [127/200]: Train loss: 5.8478, Valid loss: 4.9518


Epoch [128/200]: 100%|██████████| 9/9 [00:00<00:00, 173.22it/s, loss=4.38]


Epoch [128/200]: Train loss: 4.8957, Valid loss: 4.5009


Epoch [129/200]: 100%|██████████| 9/9 [00:00<00:00, 200.45it/s, loss=6.27]


Epoch [129/200]: Train loss: 5.5598, Valid loss: 3.9569
Saving model with loss 3.957...


Epoch [130/200]: 100%|██████████| 9/9 [00:00<00:00, 194.80it/s, loss=4.27]


Epoch [130/200]: Train loss: 5.2986, Valid loss: 4.5898


Epoch [131/200]: 100%|██████████| 9/9 [00:00<00:00, 187.34it/s, loss=4.98]


Epoch [131/200]: Train loss: 5.1683, Valid loss: 5.2223


Epoch [132/200]: 100%|██████████| 9/9 [00:00<00:00, 185.38it/s, loss=4.21]


Epoch [132/200]: Train loss: 5.2699, Valid loss: 5.0010


Epoch [133/200]: 100%|██████████| 9/9 [00:00<00:00, 180.70it/s, loss=5.36]


Epoch [133/200]: Train loss: 6.1777, Valid loss: 5.9770


Epoch [134/200]: 100%|██████████| 9/9 [00:00<00:00, 188.42it/s, loss=5.26]


Epoch [134/200]: Train loss: 4.8981, Valid loss: 4.6091


Epoch [135/200]: 100%|██████████| 9/9 [00:00<00:00, 184.88it/s, loss=5.42]


Epoch [135/200]: Train loss: 4.7248, Valid loss: 3.8320
Saving model with loss 3.832...


Epoch [136/200]: 100%|██████████| 9/9 [00:00<00:00, 187.77it/s, loss=5.25]


Epoch [136/200]: Train loss: 4.6020, Valid loss: 4.4902


Epoch [137/200]: 100%|██████████| 9/9 [00:00<00:00, 181.08it/s, loss=4.47]


Epoch [137/200]: Train loss: 4.5604, Valid loss: 5.3892


Epoch [138/200]: 100%|██████████| 9/9 [00:00<00:00, 184.42it/s, loss=5.04]


Epoch [138/200]: Train loss: 4.7627, Valid loss: 4.4278


Epoch [139/200]: 100%|██████████| 9/9 [00:00<00:00, 180.78it/s, loss=3.37]


Epoch [139/200]: Train loss: 4.4681, Valid loss: 3.9350


Epoch [140/200]: 100%|██████████| 9/9 [00:00<00:00, 181.26it/s, loss=4.2]


Epoch [140/200]: Train loss: 4.6744, Valid loss: 4.2095


Epoch [141/200]: 100%|██████████| 9/9 [00:00<00:00, 175.86it/s, loss=3.89]


Epoch [141/200]: Train loss: 4.7010, Valid loss: 4.1448


Epoch [142/200]: 100%|██████████| 9/9 [00:00<00:00, 177.77it/s, loss=3.8]


Epoch [142/200]: Train loss: 4.7821, Valid loss: 3.6202
Saving model with loss 3.620...


Epoch [143/200]: 100%|██████████| 9/9 [00:00<00:00, 184.89it/s, loss=4.13]


Epoch [143/200]: Train loss: 4.4646, Valid loss: 5.3087


Epoch [144/200]: 100%|██████████| 9/9 [00:00<00:00, 181.24it/s, loss=4.37]


Epoch [144/200]: Train loss: 4.5322, Valid loss: 4.9229


Epoch [145/200]: 100%|██████████| 9/9 [00:00<00:00, 181.15it/s, loss=4.76]


Epoch [145/200]: Train loss: 4.6845, Valid loss: 3.8551


Epoch [146/200]: 100%|██████████| 9/9 [00:00<00:00, 184.51it/s, loss=5.88]


Epoch [146/200]: Train loss: 4.7192, Valid loss: 5.8063


Epoch [147/200]: 100%|██████████| 9/9 [00:00<00:00, 182.05it/s, loss=6.57]


Epoch [147/200]: Train loss: 5.9395, Valid loss: 5.7734


Epoch [148/200]: 100%|██████████| 9/9 [00:00<00:00, 172.86it/s, loss=5.58]


Epoch [148/200]: Train loss: 5.8728, Valid loss: 3.8159


Epoch [149/200]: 100%|██████████| 9/9 [00:00<00:00, 178.37it/s, loss=5.01]


Epoch [149/200]: Train loss: 4.7560, Valid loss: 5.0780


Epoch [150/200]: 100%|██████████| 9/9 [00:00<00:00, 181.93it/s, loss=6.13]


Epoch [150/200]: Train loss: 4.9796, Valid loss: 4.9410


Epoch [151/200]: 100%|██████████| 9/9 [00:00<00:00, 68.02it/s, loss=5.23]


Epoch [151/200]: Train loss: 4.8474, Valid loss: 3.9359


Epoch [152/200]: 100%|██████████| 9/9 [00:00<00:00, 185.18it/s, loss=4.5]


Epoch [152/200]: Train loss: 4.8710, Valid loss: 5.4057


Epoch [153/200]: 100%|██████████| 9/9 [00:00<00:00, 184.60it/s, loss=4.37]


Epoch [153/200]: Train loss: 4.6350, Valid loss: 4.1593


Epoch [154/200]: 100%|██████████| 9/9 [00:00<00:00, 178.35it/s, loss=4.45]


Epoch [154/200]: Train loss: 4.3598, Valid loss: 4.2380


Epoch [155/200]: 100%|██████████| 9/9 [00:00<00:00, 173.07it/s, loss=4.5]


Epoch [155/200]: Train loss: 4.7104, Valid loss: 4.8623


Epoch [156/200]: 100%|██████████| 9/9 [00:00<00:00, 185.48it/s, loss=5.83]


Epoch [156/200]: Train loss: 4.7081, Valid loss: 4.4942


Epoch [157/200]: 100%|██████████| 9/9 [00:00<00:00, 185.40it/s, loss=5.2]


Epoch [157/200]: Train loss: 4.5980, Valid loss: 5.8612


Epoch [158/200]: 100%|██████████| 9/9 [00:00<00:00, 182.58it/s, loss=3.69]


Epoch [158/200]: Train loss: 5.0380, Valid loss: 4.1801


Epoch [159/200]: 100%|██████████| 9/9 [00:00<00:00, 182.16it/s, loss=5.04]


Epoch [159/200]: Train loss: 4.4782, Valid loss: 4.5032


Epoch [160/200]: 100%|██████████| 9/9 [00:00<00:00, 190.21it/s, loss=3.3]


Epoch [160/200]: Train loss: 4.2352, Valid loss: 3.9444


Epoch [161/200]: 100%|██████████| 9/9 [00:00<00:00, 178.10it/s, loss=6.19]


Epoch [161/200]: Train loss: 4.6847, Valid loss: 4.1394


Epoch [162/200]: 100%|██████████| 9/9 [00:00<00:00, 183.00it/s, loss=4.75]


Epoch [162/200]: Train loss: 5.0282, Valid loss: 3.7840


Epoch [163/200]: 100%|██████████| 9/9 [00:00<00:00, 177.92it/s, loss=4.13]


Epoch [163/200]: Train loss: 4.2574, Valid loss: 7.2044


Epoch [164/200]: 100%|██████████| 9/9 [00:00<00:00, 183.55it/s, loss=5.15]


Epoch [164/200]: Train loss: 5.1784, Valid loss: 5.1153


Epoch [165/200]: 100%|██████████| 9/9 [00:00<00:00, 180.30it/s, loss=5.98]


Epoch [165/200]: Train loss: 5.0109, Valid loss: 4.3132


Epoch [166/200]: 100%|██████████| 9/9 [00:00<00:00, 185.02it/s, loss=7.51]


Epoch [166/200]: Train loss: 5.1919, Valid loss: 4.3147


Epoch [167/200]: 100%|██████████| 9/9 [00:00<00:00, 180.84it/s, loss=3.62]


Epoch [167/200]: Train loss: 4.3836, Valid loss: 4.0747


Epoch [168/200]: 100%|██████████| 9/9 [00:00<00:00, 181.75it/s, loss=7.02]


Epoch [168/200]: Train loss: 4.7486, Valid loss: 4.1079


Epoch [169/200]: 100%|██████████| 9/9 [00:00<00:00, 182.15it/s, loss=8.14]


Epoch [169/200]: Train loss: 5.8472, Valid loss: 5.1062


Epoch [170/200]: 100%|██████████| 9/9 [00:00<00:00, 179.48it/s, loss=6.8]


Epoch [170/200]: Train loss: 5.6181, Valid loss: 4.5614


Epoch [171/200]: 100%|██████████| 9/9 [00:00<00:00, 181.28it/s, loss=2.82]


Epoch [171/200]: Train loss: 4.2651, Valid loss: 3.7789


Epoch [172/200]: 100%|██████████| 9/9 [00:00<00:00, 182.07it/s, loss=3.92]


Epoch [172/200]: Train loss: 4.4193, Valid loss: 4.5371


Epoch [173/200]: 100%|██████████| 9/9 [00:00<00:00, 184.05it/s, loss=4.37]


Epoch [173/200]: Train loss: 4.9723, Valid loss: 5.3496


Epoch [174/200]: 100%|██████████| 9/9 [00:00<00:00, 187.73it/s, loss=4.25]


Epoch [174/200]: Train loss: 4.8014, Valid loss: 4.6865


Epoch [175/200]: 100%|██████████| 9/9 [00:00<00:00, 183.63it/s, loss=4.95]


Epoch [175/200]: Train loss: 4.4236, Valid loss: 3.6751


Epoch [176/200]: 100%|██████████| 9/9 [00:00<00:00, 181.96it/s, loss=4.38]


Epoch [176/200]: Train loss: 4.4521, Valid loss: 4.3559


Epoch [177/200]: 100%|██████████| 9/9 [00:00<00:00, 185.35it/s, loss=3.79]


Epoch [177/200]: Train loss: 4.3461, Valid loss: 4.1366


Epoch [178/200]: 100%|██████████| 9/9 [00:00<00:00, 181.43it/s, loss=5.31]


Epoch [178/200]: Train loss: 5.1385, Valid loss: 3.9281


Epoch [179/200]: 100%|██████████| 9/9 [00:00<00:00, 186.95it/s, loss=5.26]


Epoch [179/200]: Train loss: 5.2045, Valid loss: 4.4888


Epoch [180/200]: 100%|██████████| 9/9 [00:00<00:00, 182.49it/s, loss=4.88]


Epoch [180/200]: Train loss: 4.3203, Valid loss: 4.8721


Epoch [181/200]: 100%|██████████| 9/9 [00:00<00:00, 181.19it/s, loss=3.76]


Epoch [181/200]: Train loss: 4.4350, Valid loss: 4.1935


Epoch [182/200]: 100%|██████████| 9/9 [00:00<00:00, 186.02it/s, loss=4.52]


Epoch [182/200]: Train loss: 4.2170, Valid loss: 5.4579


Epoch [183/200]: 100%|██████████| 9/9 [00:00<00:00, 184.49it/s, loss=5]


Epoch [183/200]: Train loss: 4.8504, Valid loss: 6.1746


Epoch [184/200]: 100%|██████████| 9/9 [00:00<00:00, 184.55it/s, loss=4.02]


Epoch [184/200]: Train loss: 5.1061, Valid loss: 5.9704


Epoch [185/200]: 100%|██████████| 9/9 [00:00<00:00, 183.68it/s, loss=4.12]


Epoch [185/200]: Train loss: 4.5680, Valid loss: 4.3305


Epoch [186/200]: 100%|██████████| 9/9 [00:00<00:00, 180.29it/s, loss=4.63]


Epoch [186/200]: Train loss: 5.1463, Valid loss: 8.1404


Epoch [187/200]: 100%|██████████| 9/9 [00:00<00:00, 181.21it/s, loss=3.74]


Epoch [187/200]: Train loss: 5.8851, Valid loss: 5.3107


Epoch [188/200]: 100%|██████████| 9/9 [00:00<00:00, 184.74it/s, loss=5.63]


Epoch [188/200]: Train loss: 5.0427, Valid loss: 5.4205


Epoch [189/200]: 100%|██████████| 9/9 [00:00<00:00, 183.42it/s, loss=4.21]


Epoch [189/200]: Train loss: 4.8152, Valid loss: 4.2609


Epoch [190/200]: 100%|██████████| 9/9 [00:00<00:00, 187.35it/s, loss=6.16]


Epoch [190/200]: Train loss: 5.2363, Valid loss: 6.7699


Epoch [191/200]: 100%|██████████| 9/9 [00:00<00:00, 183.38it/s, loss=6.64]


Epoch [191/200]: Train loss: 5.8780, Valid loss: 5.0492


Epoch [192/200]: 100%|██████████| 9/9 [00:00<00:00, 182.54it/s, loss=9.61]


Epoch [192/200]: Train loss: 6.0770, Valid loss: 3.9648


Epoch [193/200]: 100%|██████████| 9/9 [00:00<00:00, 178.95it/s, loss=5.44]


Epoch [193/200]: Train loss: 5.3960, Valid loss: 8.3318


Epoch [194/200]: 100%|██████████| 9/9 [00:00<00:00, 60.71it/s, loss=4.68]


Epoch [194/200]: Train loss: 5.7485, Valid loss: 4.4062


Epoch [195/200]: 100%|██████████| 9/9 [00:00<00:00, 181.64it/s, loss=4.02]


Epoch [195/200]: Train loss: 4.5081, Valid loss: 3.7487


Epoch [196/200]: 100%|██████████| 9/9 [00:00<00:00, 181.12it/s, loss=5.78]


Epoch [196/200]: Train loss: 4.5727, Valid loss: 4.2078


Epoch [197/200]: 100%|██████████| 9/9 [00:00<00:00, 181.14it/s, loss=5.91]


Epoch [197/200]: Train loss: 4.7014, Valid loss: 3.8889


Epoch [198/200]: 100%|██████████| 9/9 [00:00<00:00, 177.47it/s, loss=3.85]


Epoch [198/200]: Train loss: 4.0426, Valid loss: 3.6702


Epoch [199/200]: 100%|██████████| 9/9 [00:00<00:00, 182.50it/s, loss=4.39]


Epoch [199/200]: Train loss: 4.0385, Valid loss: 4.0515


Epoch [200/200]: 100%|██████████| 9/9 [00:00<00:00, 184.43it/s, loss=4.61]


Epoch [200/200]: Train loss: 4.4174, Valid loss: 4.5749


### Tensorboard查看曲线

In [11]:
%reload_ext tensorboard
%tensorboard --logdir=./runs/

### 测试

In [12]:
import csv

def save_pred(preds, file):
    ''' Save predictions to specified file '''
    with open(file, 'w') as fp:
        writer = csv.writer(fp)
        writer.writerow(['id', 'tested_positive'])
        for i, p in enumerate(preds):
            writer.writerow([i, p])

model = CovidNet(input_dim=x_train.shape[1]).to(device)
model.load_state_dict(torch.load(config['save_path']))
preds = predict(test_loader, model, device) 
save_pred(preds, 'pred.csv') 

100%|██████████| 5/5 [00:00<00:00, 855.67it/s]
