<a href="https://colab.research.google.com/github/guijinSON/sgMLP_Implementation/blob/main/finetue_sst2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
# Environment setup (output suppressed by %%capture): install the Hugging Face
# libraries imported below and fetch the sgMLP implementation repository.
!pip install transformers
!pip install datasets
# The clone lands in the current working directory; a later cell changes into
# /content/sgMLP_Implementation to import from it.
!git clone https://github.com/guijinSON/sgMLP_Implementation.git

In [None]:
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR
import os 
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer
# Change into the cloned repository before importing from it — presumably so the
# repo-local `models` package is found via the working directory (TODO confirm).
# NOTE: this changes the working directory for every later cell.
os.chdir('/content/sgMLP_Implementation')
#print(os.listdir())
#from CLS_model.model import build_model
from models.model import build_model

In [None]:
%%capture
# Tokenizer comes from the pretrained 'bert-base-cased' checkpoint; its
# vocabulary size is later passed to build_model when the encoder is created.
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
# GLUE SST-2: the 'train' split is used for fine-tuning, the 'validation'
# split for the periodic accuracy check in the training loop.
train_dataset = load_dataset('glue', 'sst2', split='train')
val_dataset = load_dataset('glue', 'sst2', split='validation')


In [None]:
device = torch.device('cuda:0')

# Tokenisation / batching settings shared by the train and validation pipelines.
MAX_SEQ_LEN = 64
BATCH_SIZE = 64
MODEL_COLUMNS = ['input_ids', 'token_type_ids', 'label']


def build_dataloader(dataset, shuffle):
    """Tokenise an SST-2 split and wrap it in a DataLoader.

    Each example's 'sentence' field is tokenised with the notebook-level
    `tokenizer`, truncated/padded to MAX_SEQ_LEN, and the columns listed in
    MODEL_COLUMNS are exposed as torch tensors.

    Args:
        dataset: a Hugging Face dataset that has a 'sentence' column.
        shuffle: whether the DataLoader reshuffles the data every epoch.

    Returns:
        (encoded_dataset, dataloader) tuple.
    """
    encoded = dataset.map(
        lambda e: tokenizer(e['sentence'], max_length=MAX_SEQ_LEN, truncation=True, padding='max_length'),
        batched=True,
    )
    encoded.set_format(type='torch', columns=MODEL_COLUMNS)
    loader = torch.utils.data.DataLoader(encoded, batch_size=BATCH_SIZE, shuffle=shuffle)
    return encoded, loader


train_dataset, train_dataloader = build_dataloader(train_dataset, shuffle=True)
# Validation order does not affect the aggregated accuracy, so keep it
# deterministic instead of shuffling.
val_dataset, val_dataloader = build_dataloader(val_dataset, shuffle=False)

Loading cached processed dataset at /root/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-05772391df13fcad.arrow


  0%|          | 0/1 [00:00<?, ?ba/s]

In [None]:
class SST2_head(nn.Module):
    """Binary classifier for SST-2: a pretrained encoder plus a small head.

    The encoder is created by `build_model` and initialised from a saved
    checkpoint. The head takes the encoder output at sequence position 0,
    applies a tanh-activated linear "pooler", layer normalisation and a
    two-layer MLP, and returns a sigmoid probability per example (shape
    ``(batch, 1)``), suitable for `nn.BCELoss`.

    Args:
        device: device (string or `torch.device`) on which the head layers
            and the encoder are placed and onto which the checkpoint is mapped.
        weights_path: path of the checkpoint file holding the encoder weights
            under the 'model_state_dict' key. Defaults to the location the
            notebook originally used.
    """

    # Checkpoint used by default (same directory and file name as before).
    DEFAULT_WEIGHTS_PATH = '/content/drive/Shareddrives/ICT/weights/iter_110000.pth'

    def __init__(self, device, weights_path=DEFAULT_WEIGHTS_PATH):
        super().__init__()
        self.device = device
        self.weights_path = weights_path

        # Layer creation order is kept as-is so parameter initialisation
        # consumes the RNG in the same sequence as before.
        self.pooler = nn.Linear(512, 512).to(device)
        self.projection = nn.Sequential(nn.Linear(512, 512 * 2),
                                        nn.ReLU(),
                                        nn.Linear(512 * 2, 1)).to(device)
        self.layernorm = nn.LayerNorm(512).to(device)
        self.sigmoid = nn.Sigmoid()

        self.model = self.load_model().to(device)

    def forward(self, input_ids, token_type_ids):
        """Return the positive-class probability for each example.

        Args:
            input_ids: token id tensor forwarded to the encoder.
            token_type_ids: segment id tensor forwarded to the encoder.

        Returns:
            Tensor of sigmoid outputs with a trailing dimension of size 1.
        """
        # Keep only the encoder output at position 0 of dim 1 — presumably the
        # [CLS] token representation, given the BERT tokenizer (TODO confirm).
        hidden = self.model(input_ids, token_type_ids)[:, 0]
        hidden = torch.tanh(self.pooler(hidden))
        hidden = self.layernorm(hidden)
        return self.sigmoid(self.projection(hidden))

    def load_model(self):
        """Build the encoder and load its pretrained weights.

        Reads the checkpoint at `self.weights_path` without changing the
        process working directory, and maps it onto `self.device`.

        Returns:
            The encoder with the checkpoint weights loaded.
        """
        # NOTE(review): the positional arguments look like (hidden dim, FFN dim,
        # sequence length, layer count) — confirm against build_model.
        # The device is passed as a string, as the original call did.
        model = build_model(tokenizer.vocab_size, 512, 2048, 64, 12,
                            device=str(self.device), output_logits=False)

        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(self.weights_path, map_location=self.device)
        weight = checkpoint['model_state_dict']

        # Keys are expected to carry a 'module.' prefix (presumably because the
        # checkpoint was saved from a DataParallel-wrapped model); strip it.
        # Keys without the prefix are reported and left out, as before.
        prefix = 'module.'
        model_weight = {}
        for key, val in weight.items():
            if key.startswith(prefix):
                model_weight[key[len(prefix):]] = val
            else:
                print(key)

        model.load_state_dict(model_weight)
        return model


In [None]:
EVAL_EVERY = 500  # run validation every this many optimizer steps


def evaluate(model, dataloader, device):
    """Compute classification accuracy (in percent) of `model` on `dataloader`.

    The model is switched to eval mode for the pass and restored to its
    previous mode afterwards. A prediction counts as positive when the model's
    output probability is greater than 0.5.

    Args:
        model: callable as ``model(input_ids, token_type_ids)`` returning
            probabilities with a trailing dimension of size 1.
        dataloader: yields dict batches with 'input_ids', 'token_type_ids'
            and 'label' tensors.
        device: device the batch tensors are moved to.

    Returns:
        Accuracy as a float percentage; NaN if the dataloader is empty.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for val_batch in dataloader:
            val_input_ids = val_batch['input_ids'].to(device)
            val_token_ids = val_batch['token_type_ids'].to(device)
            val_label = val_batch['label'].to(device)

            # squeeze(-1) rather than squeeze(): keeps the batch dimension
            # even when a batch holds a single example.
            val_prob = model(val_input_ids, val_token_ids).squeeze(-1)
            # Vectorised threshold instead of a Python loop over elements.
            val_pred = (val_prob > 0.5).to(val_label.dtype)
            correct += (val_pred == val_label).sum().item()
            total += val_label.size(0)
    model.train(was_training)
    return 100.0 * correct / total if total else float('nan')


sst2_model = SST2_head('cuda:0')
epochs = 10
loss_func = nn.BCELoss(reduction='mean')
opt = torch.optim.Adam(sst2_model.parameters(), lr=5e-5)
# The scheduler is stepped once per batch below, so the learning rate is
# halved every 2000 optimizer steps (not every 2000 epochs).
scheduler = StepLR(opt, step_size=2000, gamma=0.5)
step = 0
for iteration in range(epochs):
    losses = []  # per-batch training losses of the current epoch
    for batch in tqdm(train_dataloader):
        input_ids = batch['input_ids'].to(device)
        token_ids = batch['token_type_ids'].to(device)
        # BCELoss needs float targets.
        label = batch['label'].float().to(device)

        opt.zero_grad()

        pred = sst2_model(input_ids, token_ids).squeeze(-1)
        loss = loss_func(pred, label)

        loss.backward()
        opt.step()
        scheduler.step()

        losses.append(loss.item())

        if step % EVAL_EVERY == 0:
            # Validation uses its own local variables, so the training
            # batch tensors above are not overwritten.
            accuracy = evaluate(sst2_model, val_dataloader, device)
            print(f'{iteration} Epoch Running | {step} Step | Loss: {losses[-1]:.3f} | Accuracy: {accuracy:.2f}')

        step += 1



  0%|          | 2/1053 [00:01<08:22,  2.09it/s]

0 Epoch Running | 0 Step | Loss: 0.680 | Accuracy: 51.03


 48%|████▊     | 502/1053 [01:10<02:44,  3.35it/s]

0 Epoch Running | 500 Step | Loss: 0.449 | Accuracy: 74.54


 95%|█████████▌| 1002/1053 [02:20<00:15,  3.34it/s]

0 Epoch Running | 1000 Step | Loss: 0.636 | Accuracy: 74.89


100%|██████████| 1053/1053 [02:27<00:00,  7.14it/s]
 43%|████▎     | 449/1053 [01:02<03:01,  3.32it/s]

1 Epoch Running | 1500 Step | Loss: 0.186 | Accuracy: 78.78


 90%|█████████ | 949/1053 [02:12<00:31,  3.33it/s]

1 Epoch Running | 2000 Step | Loss: 0.224 | Accuracy: 79.82


100%|██████████| 1053/1053 [02:27<00:00,  7.16it/s]
 38%|███▊      | 396/1053 [00:55<03:17,  3.32it/s]

2 Epoch Running | 2500 Step | Loss: 0.194 | Accuracy: 80.50


 85%|████████▌ | 896/1053 [02:05<00:47,  3.33it/s]

2 Epoch Running | 3000 Step | Loss: 0.197 | Accuracy: 80.85


100%|██████████| 1053/1053 [02:27<00:00,  7.16it/s]
 33%|███▎      | 343/1053 [00:47<03:30,  3.38it/s]

3 Epoch Running | 3500 Step | Loss: 0.284 | Accuracy: 81.77


 80%|████████  | 843/1053 [01:57<01:02,  3.36it/s]

3 Epoch Running | 4000 Step | Loss: 0.194 | Accuracy: 81.08


100%|██████████| 1053/1053 [02:25<00:00,  7.22it/s]
 28%|██▊       | 290/1053 [00:40<03:49,  3.33it/s]

4 Epoch Running | 4500 Step | Loss: 0.058 | Accuracy: 81.19


 75%|███████▌  | 790/1053 [01:50<01:18,  3.36it/s]

4 Epoch Running | 5000 Step | Loss: 0.071 | Accuracy: 80.73


100%|██████████| 1053/1053 [02:26<00:00,  7.20it/s]
 23%|██▎       | 237/1053 [00:33<04:03,  3.35it/s]

5 Epoch Running | 5500 Step | Loss: 0.129 | Accuracy: 81.77


 70%|██████▉   | 737/1053 [01:43<01:35,  3.32it/s]

5 Epoch Running | 6000 Step | Loss: 0.112 | Accuracy: 82.00


100%|██████████| 1053/1053 [02:26<00:00,  7.17it/s]
 17%|█▋        | 184/1053 [00:26<04:19,  3.35it/s]

6 Epoch Running | 6500 Step | Loss: 0.035 | Accuracy: 80.73


 65%|██████▍   | 684/1053 [01:35<01:50,  3.33it/s]

6 Epoch Running | 7000 Step | Loss: 0.062 | Accuracy: 80.96


100%|██████████| 1053/1053 [02:26<00:00,  7.19it/s]
 12%|█▏        | 131/1053 [00:18<04:35,  3.35it/s]

7 Epoch Running | 7500 Step | Loss: 0.043 | Accuracy: 80.73


 60%|█████▉    | 631/1053 [01:28<02:05,  3.37it/s]

7 Epoch Running | 8000 Step | Loss: 0.023 | Accuracy: 81.08


100%|██████████| 1053/1053 [02:26<00:00,  7.20it/s]
  7%|▋         | 78/1053 [00:11<04:57,  3.27it/s]

8 Epoch Running | 8500 Step | Loss: 0.106 | Accuracy: 81.31


 55%|█████▍    | 578/1053 [01:21<02:22,  3.34it/s]

8 Epoch Running | 9000 Step | Loss: 0.039 | Accuracy: 80.16


100%|██████████| 1053/1053 [02:26<00:00,  7.18it/s]
  2%|▏         | 25/1053 [00:04<05:10,  3.31it/s]

9 Epoch Running | 9500 Step | Loss: 0.033 | Accuracy: 81.31


 50%|████▉     | 525/1053 [01:14<02:40,  3.30it/s]

9 Epoch Running | 10000 Step | Loss: 0.061 | Accuracy: 80.05


 97%|█████████▋| 1025/1053 [02:24<00:08,  3.35it/s]

9 Epoch Running | 10500 Step | Loss: 0.099 | Accuracy: 80.50


100%|██████████| 1053/1053 [02:28<00:00,  7.11it/s]


In [None]:
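## Hyperparameter experiment log (accuracy in %, presumably the validation accuracy printed by the training loop above)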
#1 LR: 5e-5 | 5e-4 | Steps: 5500 | Optimizer: Adam | Batch : 32
## Accuracy: 80.73

#2 LR: 5e-5 | 5e-4 | Steps: 3000 | Optimizer: Adam | Batch : 64
## Accuracy: 79.36

#3 LR: 5e-5 | 5e-5 | Steps: 3000 | Optimizer: Adam | Batch : 64 | Smaller Linear Layer
## Accuracy: 80.96

#5 LR: 5e-5 | 5e-5 | Steps: 3000 | Optimizer: Adam | Batch : 64 | Smaller Linear Layer | Modified Pooler function
## Accuracy: 81.08

#6 LR: 5e-5 | 5e-5 | Steps: 3000 | Optimizer: Adam | Batch : 64 | Smaller Linear Layer | Modified Pooler function | Added Layer Norm
## Accuracy: 81.88

#7 LR: 5e-5 | 5e-5 | Steps: 3000 | Optimizer: Adam | Batch : 64 | Smaller Linear Layer | Modified Pooler function | Added Layer Norm | Added Scheduler
## Accuracy: 79

#8 LR: 5e-5 | 5e-5 | Steps: 3000 | Optimizer: Adam | Batch : 64 | Smaller Linear Layer | Modified Pooler function | Added Layer Norm | Added Scheduler(Step LR 2000/0.5) | Added ReLU
## Accuracy: 82