## The cell below contains the APIs needed for this homework
### _Please note that all prerequisite libraries are listed in requirements.txt._

In [None]:
import numpy as np
import logging
import sklearn


import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.metrics import accuracy_score, f1_score
from tqdm import tqdm
from torch.utils.data import Dataset


# Helper functions used for training and evaluation.
def create_logger(log_path):
    """
    Create a logger that writes records both to a log file and to the console.

    The logger is obtained with ``logging.getLogger(__name__)``, so every call
    returns the same logger object. Handlers attached by an earlier call are
    removed and closed first; otherwise calling this again (e.g. re-running a
    notebook cell) stacks duplicate handlers and every record is emitted
    several times.

    :param log_path: path of the file that ``logging.FileHandler`` writes to.
    :return: the configured ``logging.Logger`` (level INFO).
    """
    x = logging.getLogger(__name__)
    x.setLevel(logging.INFO)

    # Drop handlers left over from a previous call so records are not duplicated.
    for old_handler in list(x.handlers):
        x.removeHandler(old_handler)
        old_handler.close()

    formatter = logging.Formatter(
        '%(asctime)s - %(levelname)s - %(message)s')

    # Handler that writes records to the log file.
    file_handler = logging.FileHandler(
        filename=log_path)
    file_handler.setFormatter(formatter)
    file_handler.setLevel(logging.INFO)
    x.addHandler(file_handler)

    # Handler that echoes records to the console.
    console = logging.StreamHandler()
    console.setLevel(logging.DEBUG)
    console.setFormatter(formatter)
    x.addHandler(console)
    return x


def set_seed(digit: int):
    """
    Seed every random number generator this script relies on.

    Seeds Python's built-in ``random`` module, PyTorch (``torch.manual_seed``)
    and NumPy's global RNG with the same value.

    :param digit: the seed; must be non-zero (``0``/``None`` are rejected).
    :raises ValueError: if ``digit`` is falsy.
    """
    import random  # local import: ``random`` is not imported at file level

    # Explicit check instead of ``assert`` so the guard survives ``python -O``.
    if not digit:
        raise ValueError('Please specify a non-zero number when calling this func.')
    random.seed(digit)
    torch.manual_seed(digit)
    np.random.seed(digit)


def accuracy(x: np.ndarray, y: np.ndarray, n_class: int = 3):
    """
    Count correct predictions per class.

    ``x`` holds predicted labels, read as ``x[..., 0]`` (e.g. the ``(N, 1)``
    output of ``argmax(..., keepdim=True)``); ``y`` holds the matching ground-truth
    labels.

    :param x: predicted labels, indexed as ``x[..., 0]``.
    :param y: ground-truth integer labels.
    :param n_class: number of classes reported (labels ``0 .. n_class - 1``).
    :return: ``(correct, counts)``, two lists of length ``n_class``: ``correct[c]`` is
        the number of class-``c`` samples predicted correctly and ``counts[c]`` is the
        number of class-``c`` samples. Per-class accuracy is ``correct[c] / counts[c]``;
        overall accuracy is ``sum(correct) / sum(counts)``.
    """
    labels = np.asarray(y)
    hits = np.asarray(x)[..., 0] == labels
    correct, counts = [], []
    for cls in range(n_class):
        in_class = labels == cls
        correct.append(int(np.count_nonzero(hits & in_class)))
        counts.append(int(np.count_nonzero(in_class)))
    return correct, counts


def eval_net(model, loader, device):
    """
    Evaluate ``model`` on ``loader`` and return the accuracy of each class.

    Each batch is a dict with ``'feature'`` and ``'label'`` tensors (as produced by
    ``SequenceDataset``). The predicted class is the argmax of the model's logits.
    The model is switched to eval mode for the pass and back to train mode after.

    :param model: the network, optionally wrapped in ``nn.DataParallel``.
    :param loader: iterable of batches; ``len(loader)`` sizes the progress bar.
    :param device: device the batches are moved to.
    :return: ``(acc_0, acc_1, acc_2)``, accuracy of classes 0, 1 and 2. A class with
        no samples in ``loader`` yields ``float('nan')`` rather than raising
        ``ZeroDivisionError``.
    """
    # Look through an nn.DataParallel wrapper so the isinstance checks still work.
    net = model.module if isinstance(model, nn.DataParallel) else model
    is_gru = isinstance(net, GRU)
    squeeze_input = isinstance(net, (ANN, GRU))

    model.eval()
    correct = np.zeros(3)  # correctly classified samples per class
    counts = np.zeros(3)   # samples per class

    with tqdm(total=len(loader), desc='Evaluation round', unit='batch', leave=False) as pbar:
        for bt in loader:
            xs, ys = bt['feature'], bt['label']
            if is_gru:
                xs = xs.long()  # nn.Embedding takes integer token ids
            xs = xs.to(device=device)
            ys = ys.to(device=device)

            with torch.no_grad():
                if squeeze_input:
                    # Assumes features are (batch, 1, length) -- TODO confirm. Only axis 1
                    # is squeezed: a bare squeeze() would also drop the batch axis when a
                    # batch holds a single sample.
                    xs = xs.squeeze(1)
                logits = model(xs)
                # argmax of the logits picks the same class as argmax of their softmax.
                preds = torch.argmax(logits, dim=1, keepdim=True)
            batch_correct, batch_counts = accuracy(preds.cpu().numpy(), ys.cpu().numpy())
            correct += batch_correct
            counts += batch_counts
            pbar.update()

    model.train()
    return tuple(float(c) / float(n) if n else float('nan') for c, n in zip(correct, counts))


# Below is our dataset class.
class SequenceDataset(Dataset):
    """
    Torch dataset over sequences stored in an ``.npz`` archive.

    The archive must contain a ``'reads'`` array (the features, cast to float32)
    and may contain a ``'label'`` array; without labels every sample is labelled
    ``-1``. Each sample also carries ``'index'``, its row number in the archive,
    so predictions can be mapped back after shuffling or re-balancing.
    """

    def __init__(self, x: str, shuffle: bool = True, seed: int = 1, balanced=True, delta_1=1, delta_2=1,
                 base_count: int = 19713):
        """
        :param x: path of the ``.npz`` archive.
        :param shuffle: shuffle all samples once at construction
            (``sklearn.utils.shuffle`` with ``random_state=seed``).
        :param seed: random state used by every shuffle performed here.
        :param balanced: re-balance classes: keep every class-0 sample, only the first
            ``int(base_count * delta_1)`` class-1 samples and the first
            ``int(base_count * delta_2)`` class-2 samples, then shuffle again. The kept
            subset is only random if ``shuffle`` is also set. Samples with any other
            label (including the ``-1`` placeholder) are dropped in this mode.
        :param delta_1: multiplier of ``base_count`` for the class-1 cap.
        :param delta_2: multiplier of ``base_count`` for the class-2 cap.
        :param base_count: base cap for classes 1 and 2. The default 19713 is the value
            the original code hard-coded (presumably the class-0 count of the training
            split -- TODO confirm).
        """
        super(SequenceDataset, self).__init__()
        # The context manager closes the archive's file handle once the arrays are read.
        with np.load(x) as raw_data:
            self.features = raw_data['reads'].astype(np.float32)
            try:
                labels = raw_data['label']
            except KeyError:
                # Unlabelled archive: use -1 as a placeholder label.
                labels = np.ones(self.features.shape[0]) * -1
        self.labels = labels.astype(np.int64)
        self.ids = np.arange(len(self.labels)).astype(np.int64)
        if shuffle:
            self._shuffle(seed)
        if balanced:
            # Per-class cap on the number of kept samples (None keeps the whole class).
            caps = {0: None, 1: int(base_count * delta_1), 2: int(base_count * delta_2)}
            keep = np.concatenate([np.where(self.labels == cls)[0][:cap] for cls, cap in caps.items()])
            self.features = self.features[keep]
            self.labels = self.labels[keep]
            self.ids = self.ids[keep]
            # Mix the classes again: ``keep`` lists them class by class.
            self._shuffle(seed)
        self.seed = seed

    def _shuffle(self, seed):
        """Shuffle features, labels and ids together with ``sklearn.utils.shuffle``."""
        self.features, self.labels, self.ids = sklearn.utils.shuffle(self.features,
                                                                     self.labels,
                                                                     self.ids, random_state=seed)

    def __len__(self):
        """Number of samples kept in the dataset."""
        return self.features.shape[0]

    def __getitem__(self, idx):
        """
        Return sample ``idx`` as a dict of tensors: ``'feature'`` (float32),
        ``'label'`` (int64) and ``'index'`` (int64 row number in the source archive).
        """
        feature, label, index = self.features[idx], self.labels[idx], self.ids[idx]
        return {'feature': torch.tensor(feature, dtype=torch.float32),
                'label': torch.tensor(label, dtype=torch.int64),
                'index': torch.tensor(index, dtype=torch.int64)}


# All neural-network models implemented for this task.
class ANN(nn.Module):
    """
    Fully connected classifier.

    Every hidden ``nn.Linear`` (the first maps ``in_features -> layers[0]``, the
    rest map ``layers[i] -> layers[i + 1]``) is followed by ``Tanh`` and
    ``Dropout(drop_rate)``; a final ``Linear(layers[-1], n_class)`` yields the logits.
    """

    def __init__(self, layers: list, n_class: int = 3, drop_rate=0.2, in_features: int = 250):
        """
        :param layers: hidden-layer widths; must be non-empty.
        :param n_class: number of output classes.
        :param drop_rate: dropout probability after every hidden layer.
        :param in_features: input vector length (250 in the original, hard-coded).
        """
        super(ANN, self).__init__()
        linears = [nn.Linear(in_features, layers[0]), nn.Tanh(), nn.Dropout(p=drop_rate)]
        # Each consecutive pair of widths adds one Linear -> Tanh -> Dropout group.
        for width_in, width_out in zip(layers[:-1], layers[1:]):
            linears += [nn.Linear(width_in, width_out), nn.Tanh(), nn.Dropout(p=drop_rate)]
        self.features = nn.Sequential(*linears)
        self.out = nn.Linear(layers[-1], n_class)

    def forward(self, x):
        """Map ``(..., in_features)`` inputs to ``(..., n_class)`` logits."""
        return self.out(self.features(x))


class IdentityBlock1D(nn.Module):
    """
    1-D basic residual block: ``conv -> BN -> ReLU -> conv -> BN``, added to a
    shortcut, then ReLU.

    The shortcut is the identity when input and output shapes match. When
    ``down_sample`` is set (stride 2) or ``in_channels != out_channels``, the
    shortcut is a ``1x1 conv -> BN`` projection so the addition is shape-compatible.

    Without down-sampling, the convolutions must preserve the sequence length for
    the addition to work (``padding == (kernel_size - 1) // 2`` for odd kernels).
    """

    def __init__(self, in_channels, out_channels, down_sample=False, kernel_size=3, padding=1, **kwargs):
        # NOTE: extra ``**kwargs`` are accepted but ignored.
        super(IdentityBlock1D, self).__init__()
        self.down_sample = down_sample
        stride = 2 if down_sample else 1
        # conv1 carries the stride so the feature map is down-sampled in the main path.
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, stride=stride)
        self.norm1 = nn.BatchNorm1d(out_channels)
        self.non_linear1 = nn.ReLU()
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size, padding=padding)
        self.norm2 = nn.BatchNorm1d(out_channels)
        self.non_linear2 = nn.ReLU()
        if down_sample or in_channels != out_channels:
            # Projection shortcut: matches both channel count and (strided) length.
            self.residual = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1, padding=0, stride=stride),
                nn.BatchNorm1d(out_channels)
            )
        else:
            self.residual = None

    def forward(self, x):
        """Apply the block to an input of shape ``(batch, in_channels, length)``."""
        identity = x if self.residual is None else self.residual(x)

        out = self.conv1(x)
        out = self.norm1(out)
        out = self.non_linear1(out)

        out = self.conv2(out)
        out = self.norm2(out)

        out += identity
        out = self.non_linear2(out)

        return out


class ResNet1D(nn.Module):
    """
    ResNet-style 1-D CNN classifier.

    Stem: ``Conv1d(kernel 7, stride 2) -> BN -> ReLU -> MaxPool(kernel 3, stride 2)``.
    Body: four stages of ``IdentityBlock1D``; stage 1 keeps the resolution, stages
    2-4 each down-sample by 2 in their first block. Head: global average pooling
    followed by a linear layer producing ``n_class`` logits.
    """

    def _make_stage(self, in_c, out_c, down_s=True, num_l=2):
        # Only the first block may change channels or down-sample.
        blocks = [IdentityBlock1D(in_c, out_c, down_s)]
        blocks.extend(IdentityBlock1D(out_c, out_c) for _ in range(num_l - 1))
        return nn.Sequential(*blocks)

    def _configure_layers(self, stage=18):
        # Blocks per stage: the 18-layer layout for 18, the 34-layer layout otherwise.
        return [2, 2, 2, 2] if stage == 18 else [3, 4, 6, 3]

    def __init__(self, in_channels: int, layers: list, n_class: int = 3, stages: int = 18):
        """
        :param in_channels: number of input channels.
        :param layers: output channels of the four stages (``layers[0]`` is also the stem width).
        :param n_class: number of output classes.
        :param stages: 18 selects ``[2, 2, 2, 2]`` blocks per stage; any other value ``[3, 4, 6, 3]``.
        """
        super(ResNet1D, self).__init__()
        self.conv1 = nn.Conv1d(in_channels=in_channels, out_channels=layers[0], kernel_size=7, stride=2, padding=3)
        self.norm1 = nn.BatchNorm1d(layers[0])
        self.non_linear1 = nn.ReLU()
        self.max_pool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

        n1, n2, n3, n4 = self._configure_layers(stages)
        self.stage1 = self._make_stage(layers[0], layers[0], down_s=False, num_l=n1)
        self.stage2 = self._make_stage(layers[0], layers[1], down_s=True, num_l=n2)
        self.stage3 = self._make_stage(layers[1], layers[2], down_s=True, num_l=n3)
        self.stage4 = self._make_stage(layers[2], layers[3], down_s=True, num_l=n4)

        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(layers[3], n_class)

    def forward(self, x: torch.Tensor):
        """Map ``(batch, in_channels, length)`` inputs to ``(batch, n_class)`` logits."""
        out = self.max_pool(self.non_linear1(self.norm1(self.conv1(x))))
        for stage in (self.stage1, self.stage2, self.stage3, self.stage4):
            out = stage(out)
        out = self.pool(out)
        return self.fc(out.view(out.size(0), -1))


class ResNet1D_M(nn.Module):
    """
    Modified 1-D ResNet classifier.

    Compared with ``ResNet1D``, the stem convolution uses stride 1 and has no
    max-pool, the block kernel size and padding are configurable, and when
    ``drop_rate`` is non-zero dropout is applied after stages 1-3 and after global
    average pooling.
    """

    def _make_stage(self, in_c, out_c, down_s=True, num_l=2, ks=3, pa=1):
        # The first block may change channels / down-sample; the rest keep ``out_c``.
        head = IdentityBlock1D(in_c, out_c, down_sample=down_s, kernel_size=ks, padding=pa)
        tail = [IdentityBlock1D(out_c, out_c, kernel_size=ks, padding=pa) for _ in range(num_l - 1)]
        return nn.Sequential(head, *tail)

    def _configure_layers(self, stage=18):
        # Blocks per stage: the 18-layer layout for 18, the 34-layer layout otherwise.
        return [2, 2, 2, 2] if stage == 18 else [3, 4, 6, 3]

    def __init__(self, in_channels: int, layers: list, n_class: int = 3, stages: int = 18, kernel_size=3, padding=1,
                 drop_rate=0.0):
        """
        :param in_channels: number of input channels.
        :param layers: output channels of the four stages (``layers[0]`` is also the stem width).
        :param n_class: number of output classes.
        :param stages: 18 selects ``[2, 2, 2, 2]`` blocks per stage; any other value ``[3, 4, 6, 3]``.
        :param kernel_size: kernel size of the convolutions inside every block.
        :param padding: padding of the convolutions inside every block.
        :param drop_rate: dropout probability; 0 disables dropout entirely.
        """
        super(ResNet1D_M, self).__init__()
        self.conv1 = nn.Conv1d(in_channels=in_channels, out_channels=layers[0], kernel_size=7, stride=1, padding=3)
        self.norm1 = nn.BatchNorm1d(layers[0])
        self.non_linear1 = nn.ReLU()

        depths = self._configure_layers(stages)
        block_kw = dict(ks=kernel_size, pa=padding)
        self.stage1 = self._make_stage(layers[0], layers[0], down_s=False, num_l=depths[0], **block_kw)
        self.stage2 = self._make_stage(layers[0], layers[1], down_s=True, num_l=depths[1], **block_kw)
        self.stage3 = self._make_stage(layers[1], layers[2], down_s=True, num_l=depths[2], **block_kw)
        self.stage4 = self._make_stage(layers[2], layers[3], down_s=True, num_l=depths[3], **block_kw)

        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(layers[3], n_class)
        self.drop = nn.Dropout(p=drop_rate) if drop_rate else None

    def _maybe_drop(self, t):
        """Apply dropout when it is enabled, otherwise return ``t`` unchanged."""
        return self.drop(t) if isinstance(self.drop, nn.Dropout) else t

    def forward(self, x: torch.Tensor):
        """Map ``(batch, in_channels, length)`` inputs to ``(batch, n_class)`` logits."""
        out = self.non_linear1(self.norm1(self.conv1(x)))
        for stage in (self.stage1, self.stage2, self.stage3):
            out = self._maybe_drop(stage(out))
        out = self.stage4(out)
        out = self._maybe_drop(self.pool(out))
        return self.fc(out.view(out.size(0), -1))


class GRU(nn.Module):
    """
    GRU sequence classifier over integer token ids (not one-hot data).

    Tokens are embedded, run through a stacked ``nn.GRU`` with ``batch_first=True``
    (tensors are ``(batch, seq, feature)``), mean-pooled over the sequence axis and
    mapped to ``n_class`` logits.
    """

    def __init__(self, embedding_dim: int, hidden_dim: int, num_layers: int = 2, n_class: int = 3, drop_rate=0.2,
                 vocab_size: int = 4):
        """
        :param embedding_dim: size of the learned embedding vector for each token.
        :param hidden_dim: GRU hidden size.
        :param num_layers: number of stacked GRU layers.
        :param n_class: number of output classes.
        :param drop_rate: dropout between stacked GRU layers. PyTorch applies it only
            between layers, so it is passed as 0 when ``num_layers == 1`` (same
            behavior, without PyTorch's warning).
        :param vocab_size: number of distinct token ids; the default 4 matches the
            bases A/T/G/C. Input ids must lie in ``[0, vocab_size)``.
        """
        super(GRU, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.gru = nn.GRU(embedding_dim, hidden_dim, num_layers,
                          dropout=drop_rate if num_layers > 1 else 0.0,
                          batch_first=True)
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim, n_class)
        )

    def forward(self, x):
        """Map ``(batch, sequence)`` token ids to ``(batch, n_class)`` logits."""
        embeds = self.embedding(x)           # (batch, sequence, embedding_dim)
        r_out, _ = self.gru(embeds, None)    # (batch, sequence, hidden_dim)
        # Mean-pool over all time steps instead of using only the last one
        # (suggestion from poncey).
        pooled = torch.mean(r_out, dim=1)    # (batch, hidden_dim)
        return self.fc(pooled)               # (batch, n_class)




## The cell below contains the training script
### _Please note that the configuration recorded in this section corresponds only to GRU's best-performing run; for the configurations of the other models, please contact wilszhang2-c@my.cityu.edu.hk._

In [None]:
import torch
import os
import socket
import numpy as np
import pandas as pd

from torch.utils.tensorboard import SummaryWriter
from torch.optim import SGD, Adam, lr_scheduler
from torch.utils.data import DataLoader
from tqdm import tqdm
from datetime import datetime

# Hyper-parameters of the best GRU run.
config_dict = {
    'seed': 114,
    'balanced': True,
    # multipliers for the class-1 / class-2 caps of the balanced training set
    'delta_1': 2,
    'delta_2': 2,
    # cross-entropy weight of classes 0, 1, 2
    'class_weight': [2, 1, 1],
    'model': 'gru',
    # used by the ANN and the CNNs
    'layers': [16, 32, 64, 128],
    # used by the CNNs only
    'stages': 18,
    'kernel_size': 3,
    'padding': 1,
    # used by the GRU only
    'embedding_dim': 100,
    'hidden_dim': 256,
    'num_layers': 3,
    # used by every network
    'drop_rate': 0.2,
    'epochs': 30,
    # optimizer
    'optim': 'adam',
    'lr': 2e-04,
    'decay': True,
    'batch_size': 1024,
    'one_hot': False,
}

SEED = config_dict['seed']
TIME = datetime.now().strftime('%b%d_%H-%M-%S')  # run tag used in log and output paths
CLASS_WEIGHT = torch.tensor(config_dict['class_weight'], dtype=torch.float)
DEVICE = 'cuda:0'
MODEL_DICT = None  # optional checkpoint path to resume from
PARALLEL = False   # wrap the model in nn.DataParallel
# Per the original note, ONE_HOT should be False when using the GRU or the CNN.
(BALANCED, EPOCH, LR, BATCH_SIZE, DECAY, DELTA_1, DELTA_2, ONE_HOT) = (
    config_dict[key] for key in
    ('balanced', 'epochs', 'lr', 'batch_size', 'decay', 'delta_1', 'delta_2', 'one_hot'))

# Each feature encoding ('expand' or 'one_hot') has its own train/val/test archives.
data_suffix = 'one_hot' if ONE_HOT else 'expand'
train_data = SequenceDataset(x='./train_{}.npz'.format(data_suffix), shuffle=True, seed=SEED, balanced=BALANCED,
                             delta_1=DELTA_1, delta_2=DELTA_2)
# Validation and test sets keep their original order and class distribution.
val_data = SequenceDataset(x='./val_{}.npz'.format(data_suffix), shuffle=False, balanced=False)
test_data = SequenceDataset(x='./test_{}.npz'.format(data_suffix), shuffle=False, balanced=False)
# create logger and model path.
os.makedirs('./runs/{}'.format(config_dict['model']), exist_ok=True)
if SEED:
    set_seed(SEED)
logger = create_logger('./runs/{}/{}_{}.log'.format(config_dict['model'], TIME, socket.gethostname()))
if 'cuda' in DEVICE:
    if not torch.cuda.is_available():
        raise ValueError('CUDA specified but not detected.')
    # Prefer reproducible cuDNN kernels over autotuned (non-deterministic) ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
logger.info('%-40s %s\n' % ('Using device', DEVICE))
in_channels = 4 if ONE_HOT else 1  # one-hot input has 4 channels, the expanded encoding 1
if config_dict['model'] == 'resnet_m':
    model = ResNet1D_M(in_channels=in_channels, layers=config_dict['layers'], n_class=3,
                       stages=config_dict['stages'], kernel_size=config_dict['kernel_size'],
                       padding=config_dict['padding'], drop_rate=config_dict['drop_rate'])
elif config_dict['model'] == 'resnet':
    # Plain ResNet1D; without this branch 'resnet' silently built a GRU.
    model = ResNet1D(in_channels=in_channels, layers=config_dict['layers'], n_class=3,
                     stages=config_dict['stages'])
elif config_dict['model'] == 'ann':
    model = ANN(layers=config_dict['layers'], drop_rate=config_dict['drop_rate'])
else:
    # 'gru' -- and, as before, any unrecognised model name.
    model = GRU(embedding_dim=config_dict['embedding_dim'], hidden_dim=config_dict['hidden_dim'],
                num_layers=config_dict['num_layers'], drop_rate=config_dict['drop_rate'])
if MODEL_DICT:
    # Load into the bare network *before* any DataParallel wrapping so the keys line up;
    # strip the 'module.' prefix carried by checkpoints saved from a wrapped model.
    checkpoint = torch.load(MODEL_DICT, map_location='cpu')
    prefix = 'module.'
    state_dict = {(key[len(prefix):] if key.startswith(prefix) else key): value
                  for key, value in checkpoint['model_state_dict'].items()}
    model.load_state_dict(state_dict)
if PARALLEL:
    model = nn.DataParallel(model)
model = model.to(device=DEVICE)
CLASS_WEIGHT = CLASS_WEIGHT.to(device=DEVICE)
# Class-weighted cross-entropy on the raw logits.
criterion = nn.CrossEntropyLoss(weight=CLASS_WEIGHT)
# If the results fall short of those reported in the paper, try removing weight decay.
if config_dict['optim'] == 'sgd':
    optimizer = SGD(model.parameters(), LR, momentum=0.99, weight_decay=5e-04)
else:
    optimizer = Adam(model.parameters(), LR, weight_decay=5e-04)
                                                                                                                  
logger.info(' Config '.center(80, '-'))
# (label, value) pairs printed one per line, label left-aligned in 40 columns.
config_rows = [
    ("Learning Rate", LR),
    ("Decay", DECAY),
    ("Balanced Set", BALANCED),
    ("Loss Function", criterion),
    ("Class Weight", CLASS_WEIGHT.clone().detach().cpu().numpy()),
    ("Epoch", EPOCH),
    ("Batch Size", BATCH_SIZE),
    ('Seed', SEED),
    ("One-hot feature set", ONE_HOT),
    ("Training set size", len(train_data)),
    ("Validation Set Size", len(val_data)),
    ("Test Set Size", len(test_data)),
]
name_format = '%-40s %s\n' * len(config_rows)
logger.info(name_format % tuple(value for row in config_rows for value in row))
logger.info(f'\t{model}')
logger.info('-' * 80)
# Multiply the learning rate by 0.99 after every epoch except the last one.
scheduler = (lr_scheduler.MultiStepLR(optimizer, list(range(1, EPOCH)), gamma=0.99)
             if DECAY else None)
                                                                                                                  
# Reshuffled every epoch (the training set was also shuffled once at load time).
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

writer = SummaryWriter(log_dir='./runs/{}/{}_{}/'.format(config_dict['model'], TIME, socket.gethostname()))
global_step = 0
# Type checks must look through a DataParallel wrapper (PARALLEL=True).
base_model = model.module if isinstance(model, nn.DataParallel) else model
# start training.
for epoch in range(1, EPOCH + 1):
    model.train()
    with tqdm(total=len(train_data), desc=f'Epoch {epoch}/{EPOCH}', unit='seqs') as pbar:
        for batch in train_loader:
            features = batch['feature'].to(device=DEVICE)
            if isinstance(base_model, GRU):
                features = features.long()  # nn.Embedding takes integer token ids
            labels = batch['label'].to(device=DEVICE)

            if isinstance(base_model, (ANN, GRU)):
                # Drop only the singleton channel axis (assumed (batch, 1, length) --
                # TODO confirm); a bare squeeze() also drops the batch axis when the
                # final batch holds a single sample.
                features = features.squeeze(1)
            pred = model(features)
            loss = criterion(pred, labels)
            loss_value = loss.item()
            writer.add_scalar('loss/train_criterion', loss_value, global_step)
            pbar.set_postfix(**{'loss (batch)': loss_value})

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_value_(model.parameters(), 0.1)
            optimizer.step()

            pbar.update(features.shape[0])
            global_step += 1

        if DECAY:
            scheduler.step()
            writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch)
    # Validate every 4th epoch and after the final epoch.
    if not epoch % 4 or epoch == EPOCH:
        mean_acc_0, mean_acc_1, mean_acc_2 = eval_net(model, val_loader, DEVICE)
        logger.info('Validation Accuracy of class 0 for epoch {}: {}'.format(epoch, mean_acc_0))
        logger.info('Validation Accuracy of class 1 for epoch {}: {}'.format(epoch, mean_acc_1))
        logger.info('Validation Accuracy of class 2 for epoch {}: {}'.format(epoch, mean_acc_2))
        writer.add_scalar('val_acc/0', mean_acc_0, epoch)
        writer.add_scalar('val_acc/1', mean_acc_1, epoch)
        writer.add_scalar('val_acc/2', mean_acc_2, epoch)
                                                                                                                  
# training finished, start output evaluation results.
logger.info("Saving testing set predictions for epoch {}".format(EPOCH))
# Unwrap DataParallel so the type checks, output file names and saved state_dict
# keys refer to the underlying network.
net = model.module if isinstance(model, nn.DataParallel) else model
run_dir = './runs/{}/{}_{}'.format(config_dict['model'], TIME, socket.gethostname())
test_loader = DataLoader(test_data, batch_size=200, shuffle=False)
model.eval()
outputs = []
out_ids = []
with tqdm(total=len(test_loader), desc='Output Testing result', unit='batch', leave=False) as pbar:
    for batch in test_loader:
        features, ids = batch['feature'].to(device=DEVICE), batch['index']
        if isinstance(net, GRU):
            features = features.long()  # nn.Embedding takes integer token ids

        with torch.no_grad():
            if isinstance(net, (ANN, GRU)):
                # Squeeze only axis 1 so a single-sample batch keeps its batch axis.
                features = features.squeeze(1)
            logits = model(features)
            # argmax of the logits equals argmax of their softmax.
            preds = torch.argmax(logits, dim=1, keepdim=True)
        outputs.append(preds.cpu().numpy())
        out_ids.append(ids.numpy())
        pbar.update()
outputs = np.concatenate(outputs, axis=0)
out_ids = np.concatenate(out_ids, axis=0)
# 'ID' is each sample's row number in the test archive (SequenceDataset 'index').
res = {'ID': out_ids, 'label': outputs[:, 0]}
# np.int was removed in NumPy 1.24; use an explicit 64-bit integer dtype.
df = pd.DataFrame(data=res, dtype=np.int64)
df.to_csv('{}/{}_epoch_{}.csv'.format(run_dir, net.__class__.__name__, EPOCH), index=False)
logger.info("Saving model for final epoch {}".format(EPOCH))
torch.save({
    'epoch': EPOCH,
    'model_state_dict': net.state_dict()
}, '{}/{}_epoch_{}.pt'.format(run_dir, net.__class__.__name__, EPOCH))
# Saved as a pickled object array; reload with np.load(..., allow_pickle=True).item().
np.save('{}/config_dict.npy'.format(run_dir), config_dict)
writer.close()