# In [1]:
from transformers import DistilBertTokenizerFast, AutoModel, DistilBertModel
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
from torch.nn import functional as F
import pandas as pd
from tqdm import tqdm
import copy

# code is based on repo: https://github.com/xszheng2020/memorization
# data is from https://github.com/xszheng2020/memorization/tree/master/sst/data

# Load the train / dev / test splits from CSV files in the working directory.
train, val, test = (
    pd.read_csv(csv_name) for csv_name in ("train.csv", "dev.csv", "test.csv")
)
# The number of distinct values in the "label" column sizes the classifier output.
num_classes = train["label"].nunique()
# Tokenizer matching the pretrained checkpoint loaded by the models below.
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")


class CustomDataset(Dataset):
    """Map-style dataset over a DataFrame whose first two columns are label and sentence.

    Columns are read by position, so their names do not matter.
    """

    def __init__(self, data):
        """Store the DataFrame; no copy is made."""
        self.data = data

    def __getitem__(self, index):
        """Return ``(label, sentence)`` for the row at position ``index``."""
        row = self.data.iloc[index]
        # Use .iloc explicitly: plain ``row[0]`` on a Series with string labels
        # relies on the deprecated "integer key as position" fallback.
        label = row.iloc[0]
        sentence = row.iloc[1]
        return label, sentence

    def __len__(self):
        """Return the number of rows in the underlying DataFrame."""
        return len(self.data)


class CustomDatasetWithMask(Dataset):
    """Map-style dataset that also returns a per-sample mask.

    The DataFrame's first three columns are read by position as label,
    sentence and sample index; the sample index is used to look up the mask.
    """

    def __init__(self, data, mask=None):
        """Store the DataFrame and the mask container.

        ``mask`` must support ``mask[sample_index]``. The default ``None`` is
        accepted here, but ``__getitem__`` cannot index it and will fail.
        """
        self.data = data
        self.mask = mask

    def __getitem__(self, index):
        """Return ``(label, sentence, mask)`` for the row at position ``index``."""
        row = self.data.iloc[index]
        # Use .iloc explicitly: plain ``row[0]`` on a Series with string labels
        # relies on the deprecated "integer key as position" fallback.
        label = row.iloc[0]
        sentence = row.iloc[1]
        sample_index = row.iloc[2]
        mask = self.mask[sample_index]
        return label, sentence, mask

    def __len__(self):
        """Return the number of rows in the underlying DataFrame."""
        return len(self.data)

# In [2]:
import random
import glob
import re
import pickle

# Number of samples per DataLoader batch (train, validation and test).
BATCH_SIZE = 100
# Upper bound on the number of training epochs per model.
EPOCHS = 1_000
# Training stops once this many consecutive epochs pass without the
# validation accuracy matching or beating its best value so far.
early_stop_steps = 30


def collate_fn(batch):
    """Collate dataset samples into tensors for the model.

    Each sample is indexed positionally: element 0 is the label and element 1
    is the sentence. Sentences are tokenized together with the module-level
    ``tokenizer``, padded to the longest sequence in the batch and truncated
    to at most 512 tokens.

    Returns:
        A tuple ``(labels, input_ids, attention_mask)`` of tensors.
    """
    labels = []
    sentences = []
    for sample in batch:
        labels.append(sample[0])
        sentences.append(sample[1])

    encoded = tokenizer(
        sentences,
        None,  # no second text segment
        add_special_tokens=True,
        padding=True,
        truncation=True,
        max_length=512,
        return_attention_mask=True,
        return_tensors="pt",
    )
    label_tensor = torch.tensor(labels)
    return label_tensor, encoded["input_ids"], encoded["attention_mask"]


# model without any frozen parameters: the whole network is fine-tuned (DNN)
class CustomModel_nofreeze(nn.Module):
    """Pretrained DistilBERT encoder with a linear classification head.

    No parameter is frozen here, so both the encoder and the classifier are
    trainable.
    """

    def __init__(self):
        super().__init__()
        # Hidden-state and attention outputs are disabled on the encoder.
        self.bert = DistilBertModel.from_pretrained(
            "distilbert-base-uncased",
            output_hidden_states=False,
            output_attentions=False,
        )
        # 768 input features -> one logit per class.
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, inputs, is_ids, attention_mask):
        """Encode ``inputs`` and classify the first-token embedding.

        Args:
            inputs: passed to the encoder as ``input_ids`` when ``is_ids`` is
                truthy, otherwise as ``inputs_embeds``.
            is_ids: selects how ``inputs`` is interpreted (see above).
            attention_mask: forwarded unchanged to the encoder.

        Returns:
            ``(cls_embedding, logits)``: the first-position vector of the
            encoder's first output, and the classifier output computed from it.
        """
        input_kind = "input_ids" if is_ids else "inputs_embeds"
        encoder_out = self.bert(**{input_kind: inputs}, attention_mask=attention_mask)
        last_hidden = encoder_out[0]
        # Position 0 of every sequence serves as the pooled representation.
        cls_embedding = last_hidden[:, 0, :]  # (bs, dim)
        logits = self.classifier(cls_embedding)  # (bs, num_labels)
        return cls_embedding, logits


# model with the embeddings and transformer layers 0-3 frozen; the remaining transformer layers and the classifier are trained (DNN (3 layers))
class CustomModel(nn.Module):
    """Pretrained DistilBERT encoder, partially frozen, with a linear head.

    Parameters whose name contains the embeddings prefix or the prefix of
    transformer layers 0, 1, 2 or 3 have ``requires_grad`` switched off. The
    classifier is created after the freezing pass and is therefore untouched
    by it.
    """

    def __init__(self):
        super().__init__()

        # Hidden-state and attention outputs are disabled on the encoder.
        self.bert = DistilBertModel.from_pretrained(
            "distilbert-base-uncased",
            output_hidden_states=False,
            output_attentions=False,
        )

        frozen_prefixes = (
            "bert.embeddings.",
            "bert.transformer.layer.0.",
            "bert.transformer.layer.1.",
            "bert.transformer.layer.2.",
            "bert.transformer.layer.3.",
        )
        for param_name, param in self.named_parameters():
            for prefix in frozen_prefixes:
                if prefix in param_name:
                    param.requires_grad = False
                    break

        # 768 input features -> one logit per class.
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, inputs, is_ids, attention_mask):
        """Encode ``inputs`` and classify the first-token embedding.

        Args:
            inputs: passed to the encoder as ``input_ids`` when ``is_ids`` is
                truthy, otherwise as ``inputs_embeds``.
            is_ids: selects how ``inputs`` is interpreted (see above).
            attention_mask: forwarded unchanged to the encoder.

        Returns:
            ``(cls_embedding, logits)``: the first-position vector of the
            encoder's first output, and the classifier output computed from it.
        """
        input_kind = "input_ids" if is_ids else "inputs_embeds"
        encoder_out = self.bert(**{input_kind: inputs}, attention_mask=attention_mask)
        last_hidden = encoder_out[0]
        # Position 0 of every sequence serves as the pooled representation.
        cls_embedding = last_hidden[:, 0, :]  # (bs, dim)
        logits = self.classifier(cls_embedding)  # (bs, num_labels)
        return cls_embedding, logits


# model with the whole DistilBERT encoder frozen; only the linear classifier is trained (Linear)
class CustomModel_allfreeze(nn.Module):
    """Pretrained DistilBERT encoder, fully frozen, with a linear head.

    Every encoder parameter has ``requires_grad`` switched off, leaving the
    classifier as the only trainable part.
    """

    def __init__(self):
        super().__init__()
        # Hidden-state and attention outputs are disabled on the encoder.
        self.bert = DistilBertModel.from_pretrained(
            "distilbert-base-uncased",
            output_hidden_states=False,
            output_attentions=False,
        )
        for encoder_param in self.bert.parameters():
            encoder_param.requires_grad = False
        # 768 input features -> one logit per class.
        self.classifier = nn.Linear(768, num_classes)

    def forward(self, inputs, is_ids, attention_mask):
        """Encode ``inputs`` and classify the first-token embedding.

        Args:
            inputs: passed to the encoder as ``input_ids`` when ``is_ids`` is
                truthy, otherwise as ``inputs_embeds``.
            is_ids: selects how ``inputs`` is interpreted (see above).
            attention_mask: forwarded unchanged to the encoder.

        Returns:
            ``(cls_embedding, logits)``: the first-position vector of the
            encoder's first output, and the classifier output computed from it.
        """
        input_kind = "input_ids" if is_ids else "inputs_embeds"
        encoder_out = self.bert(**{input_kind: inputs}, attention_mask=attention_mask)
        last_hidden = encoder_out[0]
        # Position 0 of every sequence serves as the pooled representation.
        cls_embedding = last_hidden[:, 0, :]  # (bs, dim)
        logits = self.classifier(cls_embedding)  # (bs, num_labels)
        return cls_embedding, logits


### Train models
def _to_device(batch, device):
    """Move the (label, input_ids, attention_mask) tensors of a batch to ``device``."""
    return batch[0].to(device), batch[1].to(device), batch[2].to(device)


def _evaluate(model, dataloader, device):
    """Run ``model`` in eval mode over ``dataloader`` without gradients.

    Returns:
        ``(loss_sum, correct)``: the cross-entropy summed over all samples and
        the number of samples whose arg-max prediction equals the label.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    with torch.no_grad():
        for batch in dataloader:
            label, input_ids, attention_mask = _to_device(batch, device)
            _, outputs = model(input_ids, True, attention_mask=attention_mask)
            # The loss is computed on THIS batch's outputs; summing per-sample
            # losses lets the caller divide by the dataset size.
            loss_sum += F.cross_entropy(outputs, label, reduction="sum").item()
            correct += (torch.argmax(outputs, dim=-1) == label).sum().item()
    return loss_sum, correct


def _fit_and_test(model_cls, lr, train_dataloader, val_dataloader, test_dataloader, device):
    """Train a fresh ``model_cls()`` with early stopping and score it on the test set.

    Only parameters with ``requires_grad`` set are handed to Adam. After every
    epoch the validation accuracy is computed; the weights of the best epoch
    (ties go to the later epoch) are kept and restored before testing.
    Training ends after ``EPOCHS`` epochs or once ``early_stop_steps``
    consecutive epochs fail to match the best validation accuracy.

    Returns:
        ``(model, best_val_acc, test_acc)`` with the best weights loaded.
    """
    n_train = len(train_dataloader.dataset)
    n_val = len(val_dataloader.dataset)
    n_test = len(test_dataloader.dataset)

    model = model_cls()
    model.to(device)
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=lr)

    best_val_acc = 0
    best_state_dict = None
    from_last_step = 0  # epochs since the best validation accuracy was set

    for i in range(EPOCHS):
        model.train()
        train_correct = 0
        total_train_loss = 0.0

        for batch in train_dataloader:
            optimizer.zero_grad()
            label, input_ids, attention_mask = _to_device(batch, device)

            _, outputs = model(input_ids, True, attention_mask=attention_mask)

            loss = F.cross_entropy(outputs, label)
            loss.backward()
            optimizer.step()
            # ``loss`` is a batch mean; weight it by the batch size so that
            # dividing by the dataset size yields a per-sample average.
            total_train_loss += loss.item() * label.size(0)
            train_correct += (torch.argmax(outputs, dim=-1) == label).sum().item()

        total_val_loss, val_correct = _evaluate(model, val_dataloader, device)

        val_acc = val_correct / n_val
        if val_acc >= best_val_acc:
            best_val_acc = val_acc
            best_state_dict = copy.deepcopy(model.state_dict())
            from_last_step = 0

        if from_last_step >= early_stop_steps:
            break
        from_last_step += 1

        if (i % 30 == 0) or (i == EPOCHS - 1):
            print(
                i,
                "train_acc:",
                round(train_correct / n_train, 6),
                "train_loss:",
                round(total_train_loss / n_train, 6),
                "val_loss:",
                round(total_val_loss / n_val, 6),
                "val_acc:",
                round(val_acc, 6),
            )

    # load best checkpoint
    model.load_state_dict(best_state_dict)
    # test accuracy
    _, test_correct = _evaluate(model, test_dataloader, device)
    test_acc = test_correct / n_test
    print(best_val_acc, test_acc)
    return model, best_val_acc, test_acc


# Fall back to CPU when no GPU is present instead of failing on ``.cuda()``.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The validation and test loaders depend neither on the seed nor on the
# training file, so they are built once.
val_dataset = CustomDataset(data=val)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    num_workers=0,
    shuffle=False,
    collate_fn=collate_fn,
)

test_dataset = CustomDataset(data=test)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    num_workers=0,
    shuffle=False,
    collate_fn=collate_fn,
)

for ind, SEED in enumerate([321, 123, 456, 654, 345, 789, 876, 907, 697]):
    torch.manual_seed(SEED)
    random.seed(SEED)
    np.random.seed(SEED)
    results = []
    ### finetune models for each datasets with top m% removed
    # Sorted so the files (and hence the RNG consumption) are visited in the
    # same order on every machine; glob itself gives no ordering guarantee.
    for filename in sorted(glob.glob("data/*.csv")):
        # The first run of digits in the path is taken as the removed percentage.
        perc = int(re.findall(r"\d+", filename)[0])
        print(perc)
        train = pd.read_csv(filename)
        train_dataset = CustomDataset(data=train)
        # NOTE(review): training batches are not shuffled; kept as in the
        # original script - confirm this is intended.
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=BATCH_SIZE,
            num_workers=0,
            shuffle=False,
            collate_fn=collate_fn,
        )

        ### DNN: everything trainable
        model, val_res_acc_nofreeze, test_res_acc_nofreeze = _fit_and_test(
            CustomModel_nofreeze,
            0.000001,
            train_dataloader,
            val_dataloader,
            test_dataloader,
            device,
        )

        ### DNN (3 layers): embeddings and transformer layers 0-3 frozen
        model, val_res_acc, test_res_acc = _fit_and_test(
            CustomModel,
            0.00001,
            train_dataloader,
            val_dataloader,
            test_dataloader,
            device,
        )

        ### Linear: encoder frozen, only the classifier is trained
        model, val_res_allfreeze_acc, test_res_allfreeze_acc = _fit_and_test(
            CustomModel_allfreeze,
            0.0001,
            train_dataloader,
            val_dataloader,
            test_dataloader,
            device,
        )

        results.append(
            [
                perc,
                val_res_acc_nofreeze,
                test_res_acc_nofreeze,
                val_res_acc,
                test_res_acc,
                val_res_allfreeze_acc,
                test_res_allfreeze_acc,
            ]
        )
    # Passing ``columns`` to the constructor also works when no data file was
    # found, where assigning ``df.columns`` afterwards would raise.
    df = pd.DataFrame(
        results,
        columns=[
            "perc",
            "full",
            "full_test",
            "3l",
            "3l_test",
            "only_classifier",
            "only_classifier_test",
        ],
    )
    df = df.sort_values("perc")
    # save results for each SEED
    with open(f"df_{ind}.pkl", "wb") as f:
        pickle.dump(df, f)