# Prepare

In [None]:
!pip install transformers

In [None]:
import pandas as pd
import numpy as np
from sklearn.metrics import f1_score
from sklearn.preprocessing import LabelEncoder
import random
import os

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm
from transformers import AutoTokenizer
from transformers import AutoModel
from torch.optim import AdamW
import torch.optim as optim

import warnings
warnings.filterwarnings(action='ignore')

# Optimizer

In [None]:
class SAM(torch.optim.Optimizer):
    """Sharpness-Aware Minimization (SAM) wrapper around a base optimizer.

    Each update is a two-pass procedure:

    1. ``first_step`` moves every parameter that has a gradient by
       ``rho * grad / ||grad||``, where ``||grad||`` is the global L2 norm
       over all gradients. This climbs towards a nearby point of higher loss.
    2. The caller recomputes loss and gradients at that perturbed point.
       ``second_step`` then restores the original weights and lets the base
       optimizer apply those gradients.

    Args:
        params: Iterable of parameters or parameter-group dicts, as accepted
            by ``torch.optim.Optimizer``.
        base_optimizer: Optimizer *class* (e.g. ``torch.optim.AdamW``) that
            performs the actual update. It is instantiated on this wrapper's
            ``param_groups`` with ``**kwargs``.
        rho: Radius of the perturbation; must be non-negative.
        **kwargs: Forwarded to ``base_optimizer`` (e.g. ``lr``).

    NOTE: the base optimizer keeps its own per-parameter state in
    ``self.base_optimizer.state``. That state is not part of this wrapper's
    ``state_dict``.
    """

    def __init__(self, params, base_optimizer, rho=0.05, **kwargs):
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, **kwargs)
        super().__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        # Share the same group dicts, so hyper-parameter edits made through
        # this wrapper (e.g. by an LR scheduler) reach the base optimizer.
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Move the weights to the perturbed point ``w + e(w)``.

        The unperturbed weights are saved so ``second_step`` can restore
        them exactly.

        Args:
            zero_grad: If True, clear gradients afterwards so the second
                backward pass starts from a clean slate.
        """
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None:
                    continue
                # Keep an exact copy of w instead of the perturbation:
                # undoing ``w + e`` with ``- e`` is not exact in floating point.
                self.state[p]["old_p"] = p.detach().clone()
                p.add_(p.grad * scale.to(p))  # climb to the local maximum "w + e(w)"

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Restore the weights saved by ``first_step``, then take the base step.

        Args:
            zero_grad: If True, clear gradients after the base optimizer step.
        """
        for group in self.param_groups:
            for p in group["params"]:
                # Restore based on what first_step saved, not on p.grad: a
                # parameter perturbed in the first pass must be put back even
                # if it got no gradient in the second pass. Popping the
                # snapshot makes a repeated second_step a no-op here.
                old_p = self.state.get(p, {}).pop("old_p", None)
                if old_p is not None:
                    p.copy_(old_p)  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        """Run a full SAM update, using ``closure`` for the second pass.

        Gradients at the current weights must already be present when this
        is called. ``closure`` must redo forward + backward; it is run with
        gradients enabled, at the perturbed point.
        """
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass

        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def load_state_dict(self, state_dict):
        """Load wrapper state and re-link the base optimizer to the new groups."""
        super().load_state_dict(state_dict)
        # Optimizer.load_state_dict replaces ``self.param_groups`` with fresh
        # dicts. Re-share them so the base optimizer sees the loaded values.
        self.base_optimizer.param_groups = self.param_groups

    def _grad_norm(self):
        """Return the global L2 norm of all existing gradients (0-dim tensor)."""
        # put everything on the same device, in case of model parallelism
        shared_device = self.param_groups[0]["params"][0].device
        norms = [
            p.grad.norm(p=2).to(shared_device)
            for group in self.param_groups
            for p in group["params"]
            if p.grad is not None
        ]
        if not norms:
            # torch.stack rejects an empty list. With no gradients there is
            # nothing to perturb, so report a zero norm.
            return torch.zeros((), device=shared_device)
        return torch.norm(torch.stack(norms), p=2)

# Train

In [None]:
#SAM

def train(model, optimizer, train_loader, test_loader, device):
    """Train ``model`` with a SAM-style optimizer and return it at its best epoch.

    Each epoch:
    - runs the two-pass SAM update on every batch of ``train_loader``;
    - evaluates with ``validation`` on ``test_loader`` and prints the losses
      and F1;
    - saves a checkpoint to ``./path{epoch}.pt``.

    Args:
        model: Module called as ``model(input_ids, attention_mask)``. Its
            output is fed to ``nn.CrossEntropyLoss``.
        optimizer: Optimizer exposing ``first_step`` / ``second_step``
            (e.g. ``SAM``).
        train_loader: Yields ``(input_ids, attention_mask, label)`` batches.
        test_loader: Validation loader, passed unchanged to ``validation``.
        device: Device the model and batches are moved to.

    Returns:
        ``model`` itself, with the weights from the epoch that achieved the
        highest validation score loaded back in. If no epoch scored above 0,
        the final weights are kept.
    """
    model.to(device)
    criterion = nn.CrossEntropyLoss().to(device)

    best_score = 0
    # CPU snapshot of the best epoch's weights. Storing ``model`` itself
    # would only alias the object that keeps training.
    best_state = None

    for epoch_step in range(1, CFG["EPOCHS"] + 1):
        model.train()
        train_loss = []

        for input_ids, attention_mask, train_label in tqdm(train_loader):
            optimizer.zero_grad()

            labels = train_label.to(device).long()
            input_id = input_ids.to(device)
            mask = attention_mask.to(device)

            # First pass: gradient at the current weights w.
            # No retain_graph: the second pass builds its own graph, so
            # keeping this one alive would only waste memory.
            loss1 = criterion(model(input_id, mask), labels)
            loss1.backward()
            optimizer.first_step(zero_grad=True)

            # Second pass: gradient at the perturbed weights w + e(w).
            loss2 = criterion(model(input_id, mask), labels)
            loss2.backward()
            optimizer.second_step(zero_grad=True)

            # NOTE: this records the loss at the perturbed point, not at w.
            train_loss.append(loss2.item())

        val_loss, val_score = validation(model, criterion, test_loader, device)

        # scheduler.step(float(np.mean(val_loss)))

        print(f'Epoch [{epoch_step}], Train Loss : [{np.mean(train_loss) :.5f}] Val Loss : [{np.mean(val_loss) :.5f}] Val F1 Score : [{val_score:.5f}]')

        torch.save(model.state_dict(), './path' + str(epoch_step) + '.pt')

        if best_score < val_score:
            best_score = val_score
            best_state = {
                name: tensor.detach().cpu().clone()
                for name, tensor in model.state_dict().items()
            }

    if best_state is not None:
        model.load_state_dict(best_state)

    return model

In [None]:
# SAM
# Build the model, wrap AdamW in SAM, and run training.
model = BaseModel()

# SAM receives the optimizer *class*. It instantiates it internally with
# the same keyword arguments (here only the learning rate).
base_optimizer = torch.optim.AdamW
optimizer = SAM(
    model.parameters(),
    base_optimizer=base_optimizer,
    lr=CFG['LEARNING_RATE'],
)
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode = 'min', factor = 0.1, patience = 1, threshold = 1e-3, verbose = True)

# NOTE(review): train() switches to model.train() at the start of every
# epoch, so this eval() only has an effect if no epoch runs. Confirm it is
# still wanted.
model.eval()

infer_model = train(
    model,
    optimizer,
    train_dataloader,
    val_dataloader,
    device,
)