In [1]:
import gc

def train_model(model, train_loader, val_loader, loss, optimizer, num_epochs, writer, epsilon):
    """Adversarially train ``model`` with FGSM and log per-epoch metrics.

    Every batch produces two optimizer steps: one on the clean inputs, and one
    on inputs perturbed by ``fgsm_attack`` (defined elsewhere) using the
    gradient of the clean loss with respect to the inputs. After each epoch the
    model is evaluated with ``compute_valid``; loss, accuracy, F1 and learning
    rate are written to ``writer`` and a summary line is printed.

    Args:
        model: network to train. Assumed to already live on the module-level
            ``device`` (TODO confirm - this function only moves the data).
        train_loader: iterable of ``(x, y)`` batches; ``y`` is flattened to 1-D.
        val_loader: validation batches, forwarded to ``compute_valid``.
        loss: criterion called as ``loss(prediction, labels)``.
        optimizer: optimizer over the model parameters; it also drives the
            cosine-annealing-with-warm-restarts LR schedule created here.
        num_epochs: number of passes over ``train_loader``.
        writer: object with ``add_scalar(tag, value, step)`` (presumably a
            TensorBoard ``SummaryWriter``).
        epsilon: perturbation size passed straight to ``fgsm_attack``.

    Returns:
        ``(loss_history, train_history, val_history, val_loss_hist)``, one
        entry per epoch. (Previously the function returned ``None``; callers
        that ignore the result are unaffected.)

    Raises:
        ValueError: if ``train_loader`` yields no batches in an epoch.
    """
    loss_history = []
    train_history = []
    val_history = []
    val_loss_hist = []
    metric_y_val = metric_p_val = None

    scheduler = lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=40,
        T_mult=2,
        eta_min=1e-9,
    )
    for epoch in range(num_epochs):
        print(epoch)
        model.train()

        correct_samples = 0
        total_samples = 0
        loss_accum = 0.0
        n_steps = 0
        # Per-epoch buffers. Previously they were initialised only on the very
        # first batch of epoch 0, so each epoch's F1/AP covered all epochs so far.
        epoch_labels = []
        epoch_preds = []

        for x, y in tqdm(train_loader):
            x_gpu = x.to(device=device)
            y_gpu = y.to(device=device).reshape((-1,))

            # --- step 1: clean sample ---
            # Track the input gradient: FGSM needs d(loss)/d(x).
            x_gpu.requires_grad = True
            prediction = model(x_gpu)
            loss_clean = loss(prediction, y_gpu)
            _, preds = torch.max(prediction, 1)

            optimizer.zero_grad()
            loss_clean.backward()
            optimizer.step()
            # Count the clean loss too (previously only the FGSM loss was summed,
            # although the average divides by two losses per batch).
            loss_accum += loss_clean.item()

            # --- step 2: FGSM sample ---
            data_grad = x_gpu.grad.detach().cpu()
            # Detach so the adversarial batch is a fresh input: otherwise its
            # backward pass also propagates into x_gpu for no purpose.
            perturbed_data = fgsm_attack(x_gpu.detach().cpu(), epsilon, data_grad)

            prediction_fgsm = model(perturbed_data.to(device=device))
            loss_fgsm = loss(prediction_fgsm, y_gpu)
            _, preds_fgsm = torch.max(prediction_fgsm, 1)

            optimizer.zero_grad()
            loss_fgsm.backward()
            optimizer.step()
            loss_accum += loss_fgsm.item()

            # Each batch contributes its labels twice: once per (clean, FGSM) pass.
            labels = y_gpu.cpu().numpy()
            batch_labels = np.concatenate((labels, labels))
            batch_preds = np.concatenate((preds.cpu().numpy(), preds_fgsm.cpu().numpy()))
            epoch_labels.append(batch_labels)
            epoch_preds.append(batch_preds)

            correct_samples += int((batch_preds == batch_labels).sum())
            total_samples += batch_labels.shape[0]
            n_steps += 1

        if n_steps == 0:
            raise ValueError("train_model: train_loader yielded no batches")

        # Mean over the 2 * n_steps optimizer steps taken this epoch.
        ave_loss = loss_accum / (n_steps * 2)
        train_accuracy = correct_samples / total_samples
        metric_y = np.concatenate(epoch_labels)
        metric_p = np.concatenate(epoch_preds)

        writer.add_scalar("Loss/train", ave_loss, epoch)
        writer.add_scalar("Acc/train", train_accuracy, epoch)
        writer.add_scalar("F1/train", f1_score(metric_y, metric_p), epoch)

        val_accuracy, loss_val, metric_y_val, metric_p_val = compute_valid(
            model, val_loader, loss, epoch, metric_y_val, metric_p_val
        )
        writer.add_scalar("Loss/valid", loss_val, epoch)
        writer.add_scalar("Acc/valid", val_accuracy, epoch)
        writer.add_scalar("F1/valid", f1_score(metric_y_val, metric_p_val), epoch)

        # Log the LR that was used for this epoch, then advance by one epoch.
        # `scheduler.step(epoch)` re-set the schedule to the epoch just finished,
        # so every epoch after the first trained with the previous epoch's LR.
        writer.add_scalar("Lr/epoch", scheduler.get_last_lr()[-1], epoch)
        scheduler.step()

        loss_history.append(float(ave_loss))
        train_history.append(train_accuracy)
        val_history.append(val_accuracy)
        val_loss_hist.append(loss_val)

        print("Average loss: %f, Val loss: %f, Train accuracy: %f, Val accuracy: %f, Train AP: %f,Val AP: %f" % (ave_loss, loss_val, train_accuracy, val_accuracy, average_precision_score(metric_y, metric_p), average_precision_score(metric_y_val, metric_p_val)))

        print('Epoch:', epoch, 'LR:', scheduler.get_last_lr())
        # Once per epoch instead of once per batch: per-batch collection is slow
        # and unnecessary, because the batch tensors are freed on rebinding.
        gc.collect()

    return loss_history, train_history, val_history, val_loss_hist
        

def compute_valid(model, loader, loss, epoch, metric_y=None, metric_p=None):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: network to evaluate; it is switched to eval mode and left there.
        loader: iterable of ``(x, y)`` batches. Inputs are cast to float and
            moved to the module-level ``device``; labels are flattened to 1-D.
        loss: criterion called as ``loss(prediction, labels)``.
        epoch: unused; kept for backward compatibility.
        metric_y, metric_p: unused; kept for backward compatibility. Earlier
            versions appended to these on every epoch after the first, which made
            the per-epoch validation F1/AP cumulative over all epochs so far.

    Returns:
        ``(val_accuracy, loss_val, metric_y, metric_p)``. ``val_accuracy`` and
        ``loss_val`` are 0-d tensors; ``loss_val`` is the mean of the per-batch
        losses. ``metric_y``/``metric_p`` are NumPy arrays holding only this
        pass's flattened labels and arg-max predictions.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()
    label_chunks = []
    pred_chunks = []
    correct_samples = 0
    total_samples = 0
    loss_accum = 0
    n_batches = 0

    with torch.no_grad():
        for x, y in loader:
            x_gpu = x.to(device=device, dtype=torch.float)
            y_gpu = y.to(device=device).reshape((-1,))

            prediction = model(x_gpu)
            loss_accum += loss(prediction, y_gpu)
            _, preds = torch.max(prediction, 1)

            # Collect per-batch chunks and concatenate once at the end, which
            # avoids re-copying the growing arrays on every batch.
            label_chunks.append(y_gpu.cpu().numpy())
            pred_chunks.append(preds.cpu().numpy())

            correct_samples += torch.sum(preds == y_gpu)
            total_samples += y_gpu.shape[0]
            n_batches += 1

    if n_batches == 0:
        raise ValueError("compute_valid: loader yielded no batches")

    loss_val = loss_accum / n_batches
    val_accuracy = correct_samples / total_samples
    return val_accuracy, loss_val, np.concatenate(label_chunks), np.concatenate(pred_chunks)
    