In [4]:
import torch
import torch.nn.functional as F
from src.methods import SophiaG
from lion_pytorch import Lion


ModuleNotFoundError: No module named 'lion_pytorch'

In [None]:
import os
import zipfile

import wget

# Fetch the CoLA v1.1 archive and unpack it next to the notebook.
# Download and extraction are guarded separately so that an archive which is
# already on disk but was never unpacked still gets extracted.
url = 'https://nyu-mll.github.io/CoLA/cola_public_1.1.zip'
archive_path = './cola_public_1.1.zip'

if not os.path.exists(archive_path):
    print('Downloading dataset')
    wget.download(url, archive_path)

if not os.path.exists('./cola_public/'):
    # Pure-Python extraction instead of a shell `unzip`, so the cell does not
    # depend on an external binary being installed.
    with zipfile.ZipFile(archive_path) as archive:
        archive.extractall('.')

In [None]:
import pandas as pd

# The training TSV is read with header=None, so column names are supplied here
# in file order.
cola_columns = ['sentence_source', 'label', 'label_notes', 'sentence']
df = pd.read_csv(
    "./cola_public/raw/in_domain_train.tsv",
    sep='\t',
    header=None,
    names=cola_columns,
)
df

In [None]:
from transformers import BertTokenizer, BertForSequenceClassification

# Load the pretrained tokenizer for the 'bert-base-uncased' checkpoint;
# lower-casing is requested explicitly via do_lower_case=True.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

In [None]:
# Encode every sentence of the dataframe into a list of token ids.
# add_special_tokens=True asks the tokenizer to include its special tokens.
input_ids = [
    tokenizer.encode(sentence, add_special_tokens=True)
    for sentence in df.sentence.values
]

# Show the first sentence next to its encoding as a sanity check.
print('Original: ', df.sentence.values[0])
print('Token IDs:', input_ids[0])

In [None]:
# Length (in token ids) of the longest encoded sentence.
print('Max sentence length: ', max(map(len, input_ids)))

In [None]:
def pad_sequences(seqs, maxlen=None, value=0, padding="post"):
    """Pad every sequence in ``seqs`` with ``value`` up to ``maxlen`` items.

    The outer list is updated in place (each element is replaced by a new,
    padded list) and is also returned for convenience. Sequences that are
    already ``maxlen`` or longer are left unchanged; nothing is truncated.

    Args:
        seqs: Mutable list of lists to pad.
        maxlen: Target length; required.
        value: Filler item appended or prepended to short sequences.
        padding: ``"post"`` to pad at the end, ``"pre"`` to pad at the start.

    Returns:
        The same ``seqs`` object, with short sequences padded.

    Raises:
        ValueError: If ``maxlen`` is None or ``padding`` is not a known mode.
    """
    if maxlen is None:
        raise ValueError("Invalid maxlen: {}".format(maxlen))
    # Reject unknown modes up front: previously they fell through both
    # branches and the input came back silently unpadded.
    if padding not in ("post", "pre"):
        raise ValueError("Invalid padding: {}".format(padding))
    for i, seq in enumerate(seqs):
        fill = [value] * max(0, maxlen - len(seq))
        seqs[i] = seq + fill if padding == "post" else fill + seq
    return seqs

# Target length: one slot longer than the longest encoded sentence, so no
# sentence actually needs truncating.
MAX_LEN = max(len(sen) for sen in input_ids) + 1
print('\nPadding/truncating all sentences to %d values...' % MAX_LEN)
print('\nPadding token: "{:}", ID: {:}'.format(tokenizer.pad_token, tokenizer.pad_token_id))

# Pad with id 0; the attention-mask cell treats every id > 0 as a real token.
input_ids = pad_sequences(input_ids, maxlen=MAX_LEN, value=0, padding="post")
# '\Done.' was an invalid escape sequence; a leading newline was intended,
# matching the other messages in this cell.
print('\nDone.')

In [None]:
# One mask row per padded sentence: 1 where the token id is positive,
# 0 where it is not (the padding id used above is 0).
attention_masks = [
    [int(token_id > 0) for token_id in padded_sentence]
    for padded_sentence in input_ids
]

In [None]:
import matplotlib.pyplot as plt
# Visual sanity check: draw mask rows 1000..1999 as a pseudocolor grid
# (one row per sentence, one column per token position).
plt.pcolor(attention_masks[1000:2000])

In [None]:
# Display the raw label array that is used as the classification target below.
df.label.values

In [None]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import matthews_corrcoef

data_split_rs = 49

# Split token ids, attention masks and labels in ONE call so that all three
# are partitioned by the same shuffled indices by construction. The previous
# two separate calls only stayed aligned because they happened to share the
# same random_state and input length.
(
    train_inputs, validation_inputs,
    train_masks, validation_masks,
    train_labels, validation_labels,
) = train_test_split(
    input_ids, attention_masks, df.label.values,
    random_state=data_split_rs, test_size=0.1,
)

In [None]:
from src.configs import set_one_device

# NOTE(review): set_one_device is a project helper not shown here; presumably
# it selects the accelerator with index 0 and returns a torch device — confirm.
device = set_one_device(0)

In [None]:
# Convert every split (ids, labels, masks) into a torch tensor, rebinding the
# same names in the same order as before.
(
    train_inputs, validation_inputs,
    train_labels, validation_labels,
    train_masks, validation_masks,
) = map(torch.tensor, (
    train_inputs, validation_inputs,
    train_labels, validation_labels,
    train_masks, validation_masks,
))

In [None]:
import numpy as np

random_seed = 4
num_workers = 1

from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler


def _seed_worker(worker_id):
    """Seed NumPy inside a DataLoader worker from its id and the notebook seed."""
    np.random.seed(worker_id + num_workers * random_seed)


### for training
batch_size = 32

# Training batches are reshuffled every epoch.
train_data = TensorDataset(train_inputs, train_masks, train_labels)
train_loader = DataLoader(
    train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers,
    worker_init_fn=_seed_worker,
)

# Validation is iterated in a fixed order. The training loop averages
# per-batch metrics, so shuffling would change which samples land in the
# (possibly shorter) last batch and make the reported numbers jitter.
validation_data = TensorDataset(validation_inputs, validation_masks, validation_labels)
val_loader = DataLoader(
    validation_data, batch_size=batch_size, shuffle=False, num_workers=num_workers,
    worker_init_fn=_seed_worker,
)

In [None]:
nets = []

history_random_seed = 74

# Two starting points (seeds 73 and 74), four models each. The seed is reset
# before every load, so the four models of a group are created under the same
# RNG state.
for starting_point_random_seed in [73, 74]:
    for _ in range(4):
        torch.manual_seed(starting_point_random_seed)
        model = BertForSequenceClassification.from_pretrained(
            "bert-base-uncased",
            num_labels=2,
            output_attentions=False,
            output_hidden_states=False,
        )
        nets.append(model)
        model.zero_grad()
        # Freeze the first 197 named parameters.
        # 197 -- 2 linears, 199 - only last linear
        for position, (_param_name, parameter) in enumerate(model.named_parameters()):
            if position < 197:
                parameter.requires_grad = False
        model.train()

# Re-seed once more after model construction.
torch.manual_seed(history_random_seed)

In [None]:
opts = []

# One optimizer per model, in groups of four: Adam, SGD, Sophia-G, Lion.
# The group offset picks the models that belong to the same starting point.
for group in range(2):
    offset = group * 4
    adam = torch.optim.Adam(nets[offset + 0].parameters(), lr=6e-4, eps=1e-8, weight_decay=0.0005)
    sgd = torch.optim.SGD(nets[offset + 1].parameters(), lr=6e-4, momentum=0.9)
    sophia = SophiaG(nets[offset + 2].parameters(), lr=6e-4, betas=(0.965, 0.99), rho=0.01, weight_decay=1e-1)
    lion = Lion(nets[offset + 3].parameters(), lr=1e-4, weight_decay=1e-2)
    opts.extend([adam, sgd, sophia, lion])

# Display names, in the same order as the optimizers above (repeated per group).
single_group_names = [
    "Adam, 6e-4",
    "SGD, 6e-4",
    "Sophia-G, 6e-4",
    "Lion, 1e-4",
]
opt_names = single_group_names * 2

In [None]:
# Batch-size multiplier per optimizer (four per group, two groups); the
# training loop divides the loss by it and steps once every bs_mul batches.
bs_muls = [1] * 4 * 2

In [None]:
hist = []

# One history record per (model, optimizer) pair. `net` and `optimizer` are
# not used in the body, but the loop variables stay bound after the loop.
for (net, optimizer, opt_name, bs_mul) in zip(nets, opts, opt_names, bs_muls):
    record = {"task_name": "BERT on CoLA", "name": opt_name, "bs_mul": bs_mul}
    # Every time series starts out as its own empty list.
    for series_key in (
        "train_loss", "train_x",
        "val_loss", "val_x",
        "train_acc_top_1", "train_acc_top_5",
        "val_acc_top_1", "val_acc_top_5",
        "epochs_x",
    ):
        record[series_key] = []
    # Bookkeeping counters and flags read by the training loop.
    record.update({"total_steps": 0, "prev_val_eval_step": 0, "batch_end": True})
    hist.append(record)

In [None]:
# Move every tensor held in each optimizer's per-parameter state to `device`;
# non-tensor entries are left as they are.
for optimizer in opts:
    for state in optimizer.state.values():
        state.update({
            key: value.to(device)
            for key, value in state.items()
            if isinstance(value, torch.Tensor)
        })

In [None]:
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy (in percent) for each k in ``topk``.

    Args:
        output: Score tensor indexed as (sample, class).
        target: 1-D tensor of true class indices, one per sample.
        topk: Iterable of k values to evaluate.

    Returns:
        A list with one 0-dim float tensor per requested k, holding the
        percentage of samples whose true class is among the k best-scored
        classes. For an empty batch every entry is 0.
    """
    batch_size = target.size(0)
    if batch_size == 0:
        # Nothing to score; also avoids dividing by zero below.
        return [torch.zeros((), device=output.device) for _ in topk]

    # Never ask for more candidates than there are classes: with k >= number
    # of classes every sample counts as correct, instead of topk() raising.
    maxk = min(max(topk), output.size(1))

    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()
    # Row r marks the samples whose r-th best prediction equals the target.
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

In [None]:
def recursive_to(param, device):
    """Move a tensor, or every tensor nested in dicts/lists, to ``device``.

    Tensors are updated in place through ``.data`` (and the ``.data`` of their
    gradient when one exists), so the objects keep their identity. Dict values
    and list items are visited recursively; anything else is ignored.
    """
    if isinstance(param, torch.Tensor):
        param.data = param.data.to(device)
        grad = param._grad
        if grad is not None:
            grad.data = grad.data.to(device)
        return

    if isinstance(param, dict):
        children = param.values()
    elif isinstance(param, list):
        children = param
    else:
        # Scalars, strings, tuples, ... hold nothing to move.
        return

    for child in children:
        recursive_to(child, device)

def optimizer_to(optim, device):
    """Move everything stored in the optimizer's param groups to ``device``.

    Only ``optim.param_groups`` is walked; each group's values are handed to
    ``recursive_to``, which relocates the tensors it finds.
    """
    for group in optim.param_groups:
        recursive_to(list(group.values()), device)

**Sophia-G training loop**

In [None]:
# Pick the model/optimizer pair that really uses SophiaG. The loop calls
# `optimizer.step(bs=...)` and `optimizer.update_hessian()`, so it must not
# rely on whatever `net` / `optimizer` an earlier cell's loop left bound.
sophia_idx = opt_names.index("Sophia-G, 6e-4")
net = nets[sophia_idx]
optimizer = opts[sophia_idx]

total_bs = len(train_loader)
block_size = 1024
# NOTE(review): total_bs is the number of batches per epoch here, not a batch
# size — confirm this is the intended scale for SophiaG's `bs` argument.
bs = total_bs * block_size  # 5120 by default in original implementation
hess_update_interval = 10  # refresh the Hessian estimate every 10th iteration
default_loss_arr, sampled_loss_arr = [], []

net.to(device)
net.train()
for epoch in range(10):
    for iter_num, data in enumerate(train_loader, 0):
        inputs, masks, labels = data[0].to(device), data[1].to(device), data[2].to(device)

        # Regular optimisation step on the supervised loss.
        outputs = net(
            inputs,
            token_type_ids=None,
            attention_mask=masks,
            labels=labels
        )
        loss = outputs.loss
        loss.backward()
        optimizer.step(bs=bs)
        optimizer.zero_grad(set_to_none=True)
        default_loss_arr.append(loss.item())

        if iter_num % hess_update_interval != hess_update_interval - 1:
            continue

        # Hessian estimate update. A fresh forward pass is required: the graph
        # of `outputs` was released by loss.backward() and the parameters were
        # modified by optimizer.step(), so it cannot be backpropagated again.
        logits = net(
            inputs,
            token_type_ids=None,
            attention_mask=masks
        ).logits
        # Labels are sampled from the model's own predictive distribution.
        samp_dist = torch.distributions.Categorical(logits=logits)
        y_sample = samp_dist.sample()
        loss_sampled = F.cross_entropy(logits.view(-1, logits.size(-1)), y_sample.view(-1), ignore_index=-1)
        loss_sampled.backward()
        optimizer.update_hessian()
        optimizer.zero_grad(set_to_none=True)
        sampled_loss_arr.append(loss_sampled.item())

**General training loop (all optimizers)**

In [None]:
from IPython import display
from matplotlib.gridspec import GridSpec  # used for the live figure layout below

# NOTE(review): group_uniques_full, make_loss_plot, make_accuracy_plot and
# draw_norm_hists_for_different_models are neither defined nor imported in the
# visible notebook; presumably project plotting helpers — import them
# explicitly so this cell works on a fresh kernel.

batch_mul_step_count = 400
calc_norm_diffs = True


def _evaluate_on_validation(model, loader, device):
    """Run ``model`` over ``loader`` without gradients.

    Returns the mean of the per-batch losses and the mean of the per-batch
    top-1 accuracies. The model is switched to eval mode for the pass and
    back to train mode afterwards.
    """
    batch_losses = []
    batch_accs = []

    model.eval()
    with torch.no_grad():
        for val_data in loader:
            val_inputs = val_data[0].to(device)
            val_masks = val_data[1].to(device)
            val_labels = val_data[2].to(device)

            val_outputs = model(
                val_inputs,
                token_type_ids=None,
                attention_mask=val_masks,
                labels=val_labels
            )
            batch_losses.append(val_outputs[0].detach().cpu().item())

            acc = accuracy(val_outputs.logits, val_labels.data, topk=(1,))
            batch_accs.append(acc[0].detach().cpu().item())
    model.train()

    return np.mean(batch_losses), np.mean(batch_accs)


for epoch in range(10):
    # Models are trained one after another within each epoch; each one is
    # moved to the device for its turn and back to the CPU afterwards.
    for (net, optimizer, net_hist) in zip(nets, opts, hist):
        net.to(device)
        optimizer_to(optimizer, device)

        total_steps = net_hist["total_steps"]
        bs_mul = net_hist["bs_mul"]

        # "linear" schedule: the multiplier grows with the step count and its
        # current value is kept in the history record.
        if net_hist["bs_mul"] == "linear":
            if not ("bs_mul_value" in net_hist):
                net_hist["bs_mul_value"] = 1

            bs_mul = net_hist["bs_mul_value"]

        net_hist["epochs_x"].append(total_steps)

        for i, data in enumerate(train_loader, 0):
            net.train()

            net_hist["batch_end"] = False
            inputs, masks, labels = data[0].to(device), data[1].to(device), data[2].to(device)

            outputs = net(
                inputs,
                token_type_ids=None,
                attention_mask=masks,
                labels=labels
            )
            # Scale the loss so gradients accumulated over bs_mul batches
            # average rather than sum.
            loss = outputs[0] / bs_mul
            loss.backward()

            # Step only on the last batch of each accumulation window.
            if total_steps % bs_mul == bs_mul - 1:
                optimizer.step()
                optimizer.zero_grad()
                net_hist["batch_end"] = True

            # Log the unscaled loss.
            net_hist["train_loss"].append(loss.detach().cpu().item() * bs_mul)
            net_hist["train_x"].append(total_steps)

            if total_steps % bs_mul == bs_mul - 1:
                if net_hist["bs_mul"] == "linear":
                    net_hist["bs_mul_value"] = int(int(total_steps) / batch_mul_step_count) + 1
                    bs_mul = net_hist["bs_mul_value"]

            top_1 = accuracy(outputs.logits, labels.data, topk=(1,))
            net_hist["train_acc_top_1"].append(top_1[0].detach().cpu().item())

            # Evaluate on the validation set once more than 20 steps have
            # passed since the last evaluation, and only right after an
            # optimizer step.
            prev_val_eval_step = net_hist["prev_val_eval_step"]
            if (total_steps - prev_val_eval_step) > 20 and net_hist["batch_end"]:
                net_hist["prev_val_eval_step"] = total_steps

                val_loss, val_acc = _evaluate_on_validation(net, val_loader, device)

                net_hist["val_loss"].append(val_loss)
                net_hist["val_x"].append(total_steps)
                net_hist["val_acc_top_1"].append(val_acc)

            # Redraw the live dashboard every 100 steps.
            if total_steps % 100 == 0:
                display.clear_output(wait=True)

                grouped_hist = group_uniques_full(hist, ["train_loss", "val_loss", "val_acc_top_1", "train_acc_top_1"])

                fig = plt.figure(figsize=(15, 8 + 2 * ((len(grouped_hist) + 2) // 3)))
                gs = GridSpec(4 + 2 * ((len(grouped_hist) + 2) // 3), 3, figure=fig)

                ax1 = fig.add_subplot(gs[0:4, :2])
                ax2 = fig.add_subplot(gs[0:2, 2])
                ax3 = fig.add_subplot(gs[2:4, 2])

                ax1 = make_loss_plot(ax1, grouped_hist, eps=0.01, make_val=False, alpha=0.9)
                ax2 = make_accuracy_plot(ax2, grouped_hist, eps=0.01, make_train=True, make_val=False, top_k=1, alpha=0.9)
                ax3 = make_accuracy_plot(ax3, grouped_hist, eps=0.01, make_train=False, make_val=True, top_k=1, alpha=0.9)

                if calc_norm_diffs:
                    draw_norm_hists_for_different_models(fig, gs[4:, :], grouped_hist, bins_n=100, draw_normal=True)

                gs.tight_layout(fig)
                plt.draw()
                plt.show()
                # Release the figure so repeated redraws do not pile up.
                plt.close(fig)

            total_steps += 1

        net_hist["total_steps"] = total_steps
        net.to("cpu")
        optimizer_to(optimizer, "cpu")

print('Finished Training')