In [4]:
import os
import sys

import time
import random
from unittest.mock import create_autospec
import warnings
import numpy as np
import argparse
import pandas as pd

import torch
import torch.nn as nn

from torch.optim import lr_scheduler, AdamW

from dataset import * #create_boolq_dataset_object, create_dataset_object, load_agnews_dataset, load_boolq_dataset, load_cb_dataset, load_imdb_dataset, load_topic_dataset, load_yelp_dataset
from dataloader import get_dataloaders

from train import test_model, train_model

from prompt import PROMPTEmbedding
from model import APT
from utils import freeze_params_encoder, get_accuracy, count_parameters, freeze_params

from transformers import RobertaTokenizer, BertTokenizer, BertForSequenceClassification, RobertaForSequenceClassification
from transformers import get_linear_schedule_with_warmup, logging

In [5]:
dataset = 'imdb'   #imdb

model_type = 'bert-base-cased'   #roberta

number_of_tokens = 20

mode = 'head'


batch_size = 16

learning_rate = 2e-5

epochs = 10



In [6]:
tokenizer = BertTokenizer.from_pretrained(model_type)

#tokenizer = RobertaTokenizer.from_pretrained(model_type)

train_text, train_labels, test_text, test_labels, valid_text, valid_labels = load_imdb_dataset(dataset)

train_data_object = create_dataset_object(train_text, train_labels, number_of_tokens, tokenizer, dataset,  mode)

test_data_object  = create_dataset_object(test_text, test_labels, number_of_tokens, tokenizer, dataset, mode)

val_data_object = create_dataset_object(valid_text, valid_labels, number_of_tokens, tokenizer, dataset, mode)

dataloaders = get_dataloaders(train_data_object, test_data_object, val_data_object, batch_size)

num_labels = 2

print("IMDB dataset loaded succesfully\n")

IMDB dataset loaded succesfully



In [7]:
model = BertForSequenceClassification.from_pretrained(model_type, 
                                                    num_labels=num_labels,
                                                    output_attentions=False,
                                                    output_hidden_states=False)


"""model = RobertaForSequenceClassification.from_pretrained(model_type, 
                                                    num_labels=num_labels,
                                                    output_attentions=False,
                                                    output_hidden_states=False)
"""

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

'model = RobertaForSequenceClassification.from_pretrained(model_type, \n                                                    num_labels=num_labels,\n                                                    output_attentions=False,\n                                                    output_hidden_states=False)\n'

In [12]:
model =  freeze_params_encoder(model, model_type)

print("Model for head finetuning loaded successfully\n")

Model for head finetuning loaded successfully



In [26]:
count_parameters(model)

+--------------------------+------------+
|         Modules          | Parameters |
+--------------------------+------------+
| bert.pooler.dense.weight |   589824   |
|  bert.pooler.dense.bias  |    768     |
+--------------------------+------------+
Total Trainable Params: 590592



In [25]:
for name, param in model.named_parameters():
	if 'classifier' not in name: # classifier layer
		if "pooler" not in name:
			param.requires_grad = False

In [None]:
count_parameters(model)

In [14]:
roberta = freeze_params(model)

prompt_emb = PROMPTEmbedding(roberta.get_input_embeddings(), 
            n_tokens= number_of_tokens, 
            initialize_from_vocab=True)

roberta.set_input_embeddings(prompt_emb)

model = APT(roberta, 8, 1, model_type)

print("Roberta APT model loaded successfully\n")

Roberta APT model loaded successfully



In [6]:
bert = freeze_params(model)

prompt_emb = PROMPTEmbedding(bert.get_input_embeddings(), 
            n_tokens= number_of_tokens, 
            initialize_from_vocab=True)

bert.set_input_embeddings(prompt_emb)

model = APT(bert, 8, 1, model_type)

print("Bert APT model loaded successfully\n")

Bert APT model loaded successfully



In [7]:
# Check GPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

#Loss function
criterion = nn.CrossEntropyLoss()

In [8]:
optimizer = AdamW(model.parameters(), lr = learning_rate, eps=1e-8)

In [9]:
scheduler = get_linear_schedule_with_warmup(
    optimizer, 
    num_warmup_steps=len(dataloaders['Train'])*epochs/15, 
    num_training_steps=len(dataloaders['Train'])*epochs
)


In [10]:
phase = 'Train'

model = model.to(device)

for idx, (data, labels) in enumerate(dataloaders[phase]):
    input_ids =  data['input_ids'].squeeze(1).to(device)
    attention_mask = data['attention_mask'].squeeze(1).to(device)
    
    
    labels = labels.to(device)


    output = model(input_ids = input_ids, attention_mask = attention_mask)


    print(output.shape)
    break

torch.Size([16, 512, 2])


In [11]:
output

tensor([[[ 0.0043, -0.4028],
         [-0.0344, -0.2357],
         [-0.0636, -0.2666],
         ...,
         [ 0.0172, -0.3997],
         [ 0.0254, -0.2515],
         [ 0.0690, -0.1504]],

        [[ 0.0179, -0.4897],
         [-0.1077, -0.2106],
         [-0.1459, -0.2359],
         ...,
         [-0.1453, -0.3558],
         [-0.0319, -0.2295],
         [-0.0077, -0.1383]],

        [[-0.0101, -0.4599],
         [-0.1921, -0.2222],
         [-0.1778, -0.1626],
         ...,
         [-0.1719, -0.3139],
         [-0.0798, -0.1796],
         [-0.0667, -0.0880]],

        ...,

        [[-0.0723, -0.3978],
         [-0.0255, -0.1854],
         [-0.1084, -0.2690],
         ...,
         [-0.0179, -0.3904],
         [-0.0332, -0.2790],
         [ 0.0229, -0.1854]],

        [[-0.0142, -0.4707],
         [-0.0805, -0.2840],
         [-0.0953, -0.3215],
         ...,
         [-0.0783, -0.4438],
         [-0.0064, -0.3567],
         [ 0.0562, -0.2425]],

        [[ 0.0011, -0.4781],
       

In [14]:
model.base_model.classifier

Linear(in_features=768, out_features=2, bias=True)

In [9]:
config_train = {
    
    'dataset': dataset,
    'dataloaders':dataloaders, 
    'model': model, 
    'device': device, 
    'criterion':criterion, 
    'optimizer':optimizer, 
    'mode':mode, 
    'scheduler': scheduler,
    'epochs': epochs,
    'save_checkpoint': True,
    'checkpoint': None,
    'model_type': model_type
}

train_model(config_train)

Model will be saved at saved_models/bert-base-cased_cb_apt.pt
Epoch 1/10


Train:   0%|          | 0/15 [00:04<?, ?batch/s]


RuntimeError: Expected target size [16, 3], got [16]

In [10]:
train_data_object[5][0]['input_ids'][0]

tensor([    0,  7424,    47,   304, 23136,  3121,  1886,    23,   364,  3275,
         1075,  1992,     2,     2,   717,  3275,  1075, 10780,  1992,   480,
        39414,  1075, 10780,  1992,  4542,     5,  1139,     9, 39414,  1075,
           11, 15693,     4,    85,    16,  2034,   160, 21478,  1214,     8,
           16,   540,    87,    80,   728,   108,  1656,    31,     5,   755,
          852,     4,    85,    16,    45,    11,     5,   928, 17311,  3121,
         1886,  2056,  7328, 39414,  1075, 21144,    50, 32699,   225,  1908,
        15211,  4492,     4,    20,  1992,   745,    21,  4209,    11,  1125,
           73, 10684,    19,    10,    92,   745,    19, 10867,  1065,     5,
         1992,    36,  7048,   253,     9,  1566,   322,     2,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1, 