In [None]:
# Load the Drive helper and mount
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
folder_path = "/content/drive/MyDrive/EmbeddingsAttack/out/reddit_chunked/"

In [None]:
import torch
import pandas as pd
import pickle
import numpy as np

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
BATCH_SIZE = 32

In [None]:
TRAIN_FRAC = 0.8  # use 10 percent of authors for training
TEST_FRAC = 0.2
folder_path = "/content/drive/MyDrive/EmbeddingsAttack/out/reddit_chunked/"
with open(folder_path + "reddit_train_embeddings_20240126_181755.pickle", "rb") as f:
    train_chunks = pickle.load(f)

with open(folder_path + "reddit_test_embeddings_20240126_181755.pickle", "rb") as f:
    test_chunks = pickle.load(f)


all_chunks = train_chunks + test_chunks
all_chunks = all_chunks
text2embedding = {
    elem["text"]: torch.tensor(elem["embedding"]) for elem in all_chunks
}
text2hash = {
    elem["text"]: hash(elem["text"]) for elem in all_chunks
}
hash2text = {
    a:b for b,a in text2hash.items()
}

hash2embedding = {
    text2hash[text]: text2embedding[text] for text in text2embedding
}

df_total = pd.DataFrame(all_chunks)[["metadata", "text"]]
df_total["author"] = df_total["metadata"].apply(lambda x: x['author'])
df = df_total

In [None]:
most_frequent_authors = df["author"].value_counts().index[:10]
df = df[df["author"].isin(most_frequent_authors)]

df["text_hash"] = df["text"].apply(lambda x: text2hash[x])
df = df[["author", "text_hash"]]

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df["text_hash"] = df["text"].apply(lambda x: text2hash[x])


In [None]:
train_df = df.iloc[:int(len(df) * TRAIN_FRAC)]
test_df = df.iloc[int(len(df) * TRAIN_FRAC):]

In [None]:
# create dataloader for pairs of texts, such that the author can be the same or not the same and if the author is the same, the label is 1, otherwise 0

class TextPairDataset(torch.utils.data.Dataset):
    def __init__(self, df, text2hash, hash2embedding, duplicate_factor = 1):
        self.df = df
        self.text2hash = text2hash
        self.hash2embedding = hash2embedding
        self.texts = df["text_hash"].values
        self.authors = df["author"].values
        self.duplicate_factor = duplicate_factor

    def __len__(self):
        return len(self.texts)*self.duplicate_factor

    def __getitem__(self, idx):
        idx = idx // self.duplicate_factor
        sample_same = torch.randn(1).item() > 0
        if sample_same:
            author = self.authors[idx]
            same_author_idx = np.where(self.authors == author)[0]
            other_text_idx = torch.randint(0, len(same_author_idx), (1,)).item()
            other_text_idx = same_author_idx[other_text_idx]
        else:
            other_text_idx = torch.randint(0, len(self.texts), (1,)).item()
            iter = 0
            while other_text_idx == idx or self.authors[idx] == self.authors[other_text_idx]:
                other_text_idx = torch.randint(0, len(self.texts), (1,)).item()
                iter += 1
                if iter > 20:
                    break

        sample_same = self.authors[idx] == self.authors[other_text_idx]
        assert (self.authors[idx] == self.authors[other_text_idx]) == sample_same, f"author: {self.authors[idx]}, other author: {self.authors[other_text_idx]}, sample_same: {sample_same}"

        emb1 = self.hash2embedding[self.texts[idx]]
        emb2 = self.hash2embedding[self.texts[other_text_idx]]

        embedding_concat = torch.cat([emb1, emb2], dim=0)

        return embedding_concat, sample_same



In [None]:
ds = TextPairDataset(train_df, text2hash, hash2embedding)

In [None]:
VAL_FRAC = 0.2
val_df = train_df.iloc[:int(len(train_df) * VAL_FRAC)]
train_df = train_df.iloc[int(len(train_df) * VAL_FRAC):]

In [None]:
train_loader = torch.utils.data.DataLoader(TextPairDataset(train_df, text2hash, hash2embedding, duplicate_factor=100), batch_size=BATCH_SIZE, shuffle=True)
val_loader = torch.utils.data.DataLoader(TextPairDataset(val_df, text2hash, hash2embedding, duplicate_factor=100), batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(TextPairDataset(test_df, text2hash, hash2embedding, duplicate_factor=100), batch_size=BATCH_SIZE, shuffle=True)

In [None]:
len(val_df)

1667

In [None]:
import torch.nn as nn

class One_Layer_MLP(torch.nn.Module):
  def __init__(self,  n_input_units):

    super(One_Layer_MLP, self).__init__()

    self.n_input_units = n_input_units
    self.fc1 = torch.nn.Linear(n_input_units, 1)

  def forward(self, x):

    x = self.fc1(x)
    return x


class Linear_skip_block(nn.Module):
  """
  Block of linear layer + softplus + skip connection +  dropout  + batchnorm
  """
  def __init__(self, n_input, dropout_rate):
    super(Linear_skip_block, self).__init__()

    self.fc = nn.Linear(n_input, n_input)
    self.act = torch.nn.LeakyReLU()

    self.bn = nn.BatchNorm1d(n_input, affine = True)
    self.drop = nn.Dropout(dropout_rate)

  def forward(self, x):
    x0 = x
    x = self.fc(x)
    x = self.act(x)
    x = x0 + x
    x = self.drop(x)
    x = self.bn(x)

    return x

class Linear_block(nn.Module):
  """
  Block of linear layer dropout  + batchnorm
  """
  def __init__(self, n_input, n_output, dropout_rate):
    super(Linear_block, self).__init__()

    self.fc = nn.Linear(n_input, n_output)
    self.act = torch.nn.LeakyReLU()
    self.bn = nn.BatchNorm1d(n_output, affine = True)
    self.drop = nn.Dropout(dropout_rate)

  def forward(self, x):
    x = self.fc(x)
    x = self.act(x)
    x = self.drop(x)
    x = self.bn(x)

    return x

class MLP(nn.Module):
  def __init__(self, n_input_units, n_hidden_units, n_skip_layers, dropout_rate):

    super(MLP, self).__init__()
    self.n_input_units = n_input_units
    self.n_hidden_units = n_hidden_units
    self.n_skip_layers = n_skip_layers
    self.dropout_rate = dropout_rate

    self.linear1 = Linear_block(n_input_units, n_hidden_units, dropout_rate)    # initial linear layer
    self.hidden_layers = torch.nn.Sequential(*[Linear_skip_block(n_hidden_units, dropout_rate) for _ in range(n_skip_layers)])  #hidden skip-layers

    self.linear_final =  torch.nn.Linear(n_hidden_units, 1)

  def forward(self, x):
    x = self.linear1(x)
    x = self.hidden_layers(x)
    x = self.linear_final(x)

    return(x)

In [None]:
import numpy as np
import time
from tqdm import tqdm
#Validation function

from sklearn.metrics import accuracy_score, f1_score, matthews_corrcoef, precision_score, recall_score

def validate(model, dataloader, loss_fun):
    val_loss_lis = []

    target_lis = []
    pred_lis = []



    model.eval()
    with torch.no_grad():
        for batch in dataloader:

          X, y = batch
          X = X.to(device)
          y = y.to(device)

          pred = model(X)
          pred = pred.squeeze(-1)
          loss = loss_fun(pred, y.float())
          val_loss_lis.append(loss.cpu().detach())

          target_lis.append(y.detach().cpu())
          pred_lis.append(pred.detach().cpu())

    mean_loss = np.mean(np.array(val_loss_lis))
    median_loss = np.median(np.array(val_loss_lis))

    target_ten, pred_ten = torch.cat(target_lis), torch.cat(pred_lis)

    pred_binary = (pred_ten > 0.5).float().cpu().detach().numpy()
    acc = accuracy_score(pred_binary, target_ten)
    f1 = f1_score(pred_binary, target_ten)
    mcc = matthews_corrcoef(pred_binary, target_ten)
    precision = precision_score(pred_binary, target_ten)
    recall = recall_score(pred_binary, target_ten)
    class_imabalance = np.mean(target_ten.cpu().detach().numpy())

    return mean_loss, median_loss, acc, f1, mcc, precision, recall, class_imabalance



# Training function
def train_loop(model, optimizer, loss_fun, trainset, valset, print_mod, device, n_epochs, save_path = None, early_stopping = True, n_epochs_early_stopping = 5):
    """
    train the model
    Args:
        model: The model to train
        optimizer: The used optimizer
        loss_fun: The used loss function
        trainset: The dataset to train on
        valset: The dataset to use for validation
        print_mod: Number of epochs to print result after
        device: Either "cpu" or "cuda"
        n_epochs: Number of epochs to train
        save_path: Path to save the model's state dict
        config: config file from the model to train
        sparse_ten (bool): if a sparse tensor is used for each batch
    """
    if early_stopping == True:
      n_early_stopping = n_epochs_early_stopping
      past_val_losses = []

    loss_lis = []
    target_lis = []
    pred_lis = []

    loss_lis_all = []
    val_loss_lis_all = []

    model = model.to(device)

    model.train()
    for epoch in range(n_epochs):
      start = time.time()
      for iter, batch in enumerate(tqdm(trainset)):

        X, y = batch
        X = X.to(device)
        y = y.to(device)

        pred = model(X)
        pred = pred.squeeze(-1)




        loss = loss_fun(pred, y.float())
        #print(loss)

        optimizer.zero_grad()       # clear previous gradients
        loss.backward()             # backprop

        optimizer.step()

        loss_lis.append(loss.cpu().detach())
        target_lis.append(y.detach().cpu())
        pred_lis.append(pred.detach().cpu())

      if epoch % print_mod == 0:

        end = time.time()
        time_delta = end - start

        mean_loss = np.mean(np.array(loss_lis))
        median_loss = np.median(np.array(loss_lis))

        target_ten, pred_ten = torch.cat(target_lis), torch.cat(pred_lis)

        pred_binary = (pred_ten > 0.5).float().cpu().detach().numpy()
        acc = accuracy_score(pred_binary, target_ten)
        f1 = f1_score(pred_binary, target_ten)
        mcc = matthews_corrcoef(pred_binary, target_ten)
        precision = precision_score(pred_binary, target_ten)
        recall = recall_score(pred_binary, target_ten)
        class_imabalance = np.mean(target_ten.cpu().detach().numpy())

        target_lis = []
        pred_lis = []



        loss_lis_all += loss_lis

        loss_lis = []

        mean_loss_val, median_loss_val, acc_val, f1_val, mcc_val, precision_val, recall_val, class_imabalance_val = validate(model, valset, loss_fun)

        val_loss_lis_all.append(mean_loss_val)



        print(f'Epoch nr {epoch}: mean_train_loss = {mean_loss}, median_train_loss = {median_loss}, train_acc = {acc}, train_f1 = {f1}, train_mcc = {mcc}, train_precision = {precision}, train_recall = {recall}, class_imbalance = {class_imabalance}, time = {time_delta}')
        print(f'Epoch nr {epoch}: mean_valid_loss = {mean_loss_val}, median_valid_loss = {median_loss_val}, valid_acc = {acc_val}, valid_f1 = {f1_val}, valid_mcc = {mcc_val}, valid_precision = {precision_val}, valid_recall = {recall_val}, class_imbalance = {class_imabalance_val}, time = {time_delta}')



        # early stopping based on median validation loss:
        if early_stopping:
          if len(past_val_losses) == 0 or mean_loss_val < min(past_val_losses):
            print("save model")
            torch.save(model.state_dict(), save_path)

          if len(past_val_losses) >= n_early_stopping:
            if mean_loss_val > max(past_val_losses):
              print(f"Early stopping because the median validation loss has not decreased since the last {n_early_stopping} epochs")
              return loss_lis_all, val_loss_lis_all
            else:
              past_val_losses = past_val_losses[1:] + [mean_loss_val]
          else:
            past_val_losses = past_val_losses + [mean_loss_val]



    return loss_lis_all, val_loss_lis_all

In [None]:
mlp1 = MLP(3072 , 768, 5, 0.3).to(device)

In [None]:
lr = 1e-4
model = mlp1
loss = torch.nn.BCEWithLogitsLoss()
save_path = "mlp1.pth"
opt = torch.optim.AdamW(model.parameters(), lr=lr)

In [None]:
r = train_loop(model = model,
           optimizer = opt,
           loss_fun = loss,
           trainset = train_loader,
           valset = val_loader,
           print_mod = 1,
           device = device,
           early_stopping = True,
           n_epochs_early_stopping = 5,
           save_path = "mlp1.pth",
              n_epochs = 100)

100%|██████████| 20838/20838 [03:10<00:00, 109.31it/s]


Epoch nr 0: mean_train_loss = 0.4904441833496094, median_train_loss = 0.47691333293914795, train_acc = 0.727128074385123, train_f1 = 0.6922557806376757, train_mcc = 0.46658802558300155, train_precision = 0.6134626337719956, train_recall = 0.7942720967510314, class_imbalance = 0.5002849430113977, time = 190.64577078819275
Epoch nr 0: mean_valid_loss = 0.6683225035667419, median_valid_loss = 0.6578793525695801, valid_acc = 0.6260947810437912, valid_f1 = 0.5572587404639797, valid_mcc = 0.26471968732426826, valid_precision = 0.4713247221387804, valid_recall = 0.6815157148565769, class_imbalance = 0.499250149970006, time = 190.64577078819275
save model


100%|██████████| 20838/20838 [03:04<00:00, 113.05it/s]


Epoch nr 1: mean_train_loss = 0.15463069081306458, median_train_loss = 0.1383957862854004, train_acc = 0.9330713857228554, train_f1 = 0.9326741754743795, train_mcc = 0.866207371235375, train_precision = 0.9269936185015474, train_recall = 0.9384247818774628, class_imbalance = 0.5000959808038392, time = 184.33452606201172
Epoch nr 1: mean_valid_loss = 1.8623181581497192, median_valid_loss = 1.836830735206604, valid_acc = 0.5802999400119976, valid_f1 = 0.4223484535742003, valid_mcc = 0.1943123822649192, valid_precision = 0.30592301987895604, valid_recall = 0.6818351460865856, class_imbalance = 0.5015356928614277, time = 184.33452606201172


100%|██████████| 20838/20838 [03:01<00:00, 114.59it/s]


Epoch nr 2: mean_train_loss = 0.055339887738227844, median_train_loss = 0.0342259407043457, train_acc = 0.9786592681463707, train_f1 = 0.978678325911453, train_mcc = 0.9573189700434434, train_precision = 0.9782005085948943, train_recall = 0.9791566102507983, class_imbalance = 0.5006913617276545, time = 181.85622692108154
Epoch nr 2: mean_valid_loss = 2.4933085441589355, median_valid_loss = 2.4781675338745117, valid_acc = 0.5835692861427715, valid_f1 = 0.42760906669744975, valid_mcc = 0.19748545606859846, valid_precision = 0.31181981071948245, valid_recall = 0.680184670269136, class_imbalance = 0.4988422315536893, time = 181.85622692108154


100%|██████████| 20838/20838 [03:02<00:00, 113.88it/s]


Epoch nr 3: mean_train_loss = 0.02567809261381626, median_train_loss = 0.009267398156225681, train_acc = 0.9899040191961608, train_f1 = 0.989850440236401, train_mcc = 0.9798111147782113, train_precision = 0.9885031858641581, train_recall = 0.9912013720303876, class_imbalance = 0.4980383923215357, time = 182.99867701530457
Epoch nr 3: mean_valid_loss = 2.73236083984375, median_valid_loss = 2.7040822505950928, valid_acc = 0.5795800839832034, valid_f1 = 0.4195461321848601, valid_mcc = 0.1897564803184181, valid_precision = 0.30426216904521647, valid_recall = 0.6754853851077448, class_imbalance = 0.4993641271745651, time = 182.99867701530457


100%|██████████| 20838/20838 [03:01<00:00, 114.96it/s]


Epoch nr 4: mean_train_loss = 0.01773754507303238, median_train_loss = 0.004068843089044094, train_acc = 0.9930773845230954, train_f1 = 0.9930867362985283, train_mcc = 0.986156164671564, train_precision = 0.9922485956191097, train_recall = 0.9939262941100642, class_imbalance = 0.5010992801439712, time = 181.27217817306519
Epoch nr 4: mean_valid_loss = 2.1499428749084473, median_valid_loss = 2.1217617988586426, valid_acc = 0.5925914817036593, valid_f1 = 0.4639657771568836, valid_mcc = 0.2112993831451685, valid_precision = 0.3525192798973338, valid_recall = 0.678454364987766, class_imbalance = 0.5001619676064787, time = 181.27217817306519


100%|██████████| 20838/20838 [03:01<00:00, 114.87it/s]


Epoch nr 5: mean_train_loss = 0.013855164870619774, median_train_loss = 0.002224269090220332, train_acc = 0.9943266346730654, train_f1 = 0.9943190265307317, train_mcc = 0.9886542924267023, train_precision = 0.9935983433124747, train_recall = 0.9950407559691264, class_imbalance = 0.49969256148770247, time = 181.41322231292725
Epoch nr 5: mean_valid_loss = 2.3861944675445557, median_valid_loss = 2.363478183746338, valid_acc = 0.5916076784643072, valid_f1 = 0.44873071784282764, valid_mcc = 0.21452788999602762, valid_precision = 0.33230592101318046, valid_recall = 0.6907314154659221, class_imbalance = 0.5001859628074385, time = 181.41322231292725


100%|██████████| 20838/20838 [03:02<00:00, 114.38it/s]


Epoch nr 6: mean_train_loss = 0.01233094371855259, median_train_loss = 0.0015489619690924883, train_acc = 0.9949400119976005, train_f1 = 0.9949290304739118, train_mcc = 0.9898809484136571, train_precision = 0.9942326928218053, train_recall = 0.9956263442073124, class_imbalance = 0.4992666466706659, time = 182.19059205055237
Epoch nr 6: mean_valid_loss = 2.636901617050171, median_valid_loss = 2.6124095916748047, valid_acc = 0.5781823635272946, valid_f1 = 0.44631847494862165, valid_mcc = 0.17814690943647435, valid_precision = 0.3398608945916777, valid_recall = 0.6498887844252333, class_imbalance = 0.500239952009598, time = 182.19059205055237


100%|██████████| 20838/20838 [03:02<00:00, 114.10it/s]


Epoch nr 7: mean_train_loss = 0.011124277487397194, median_train_loss = 0.001053161104209721, train_acc = 0.9953464307138572, train_f1 = 0.9953504819585814, train_mcc = 0.9906936254605698, train_precision = 0.9947319241205398, train_recall = 0.9959698095555096, class_imbalance = 0.500746850629874, time = 182.64097833633423
Epoch nr 7: mean_valid_loss = 2.292752265930176, median_valid_loss = 2.275287628173828, valid_acc = 0.5935152969406119, valid_f1 = 0.4608836095441924, valid_mcc = 0.21405993614073038, valid_precision = 0.3478824857671335, valid_recall = 0.6826141264641418, class_imbalance = 0.4994481103779244, time = 182.64097833633423


100%|██████████| 20838/20838 [03:01<00:00, 114.85it/s]


Epoch nr 8: mean_train_loss = 0.010614250786602497, median_train_loss = 0.0007939180359244347, train_acc = 0.9954484103179364, train_f1 = 0.9954497546465728, train_mcc = 0.9908978131853086, train_precision = 0.9947473137288979, train_recall = 0.9961531883227361, class_imbalance = 0.500500899820036, time = 181.45130348205566
Epoch nr 8: mean_valid_loss = 2.3988454341888428, median_valid_loss = 2.369483232498169, valid_acc = 0.5801499700059988, valid_f1 = 0.42995023497886414, valid_mcc = 0.18681647923016317, valid_precision = 0.31742252047479885, valid_recall = 0.6660778276888911, class_imbalance = 0.49880623875224955, time = 181.45130348205566


100%|██████████| 20838/20838 [03:00<00:00, 115.25it/s]


Epoch nr 9: mean_train_loss = 0.009810848161578178, median_train_loss = 0.0006672918098047376, train_acc = 0.9958908218356328, train_f1 = 0.9958857252685539, train_mcc = 0.9917824593626013, train_precision = 0.9952431257915618, train_recall = 0.9965291550971836, class_imbalance = 0.4997030593881224, time = 180.82779335975647
Epoch nr 9: mean_valid_loss = 3.2928004264831543, median_valid_loss = 3.261770725250244, valid_acc = 0.5731493701259748, valid_f1 = 0.44470977509325593, valid_mcc = 0.16502487138421407, valid_precision = 0.3418435291718156, valid_recall = 0.6361322587126879, class_imbalance = 0.50000599880024, time = 180.82779335975647
Early stopping because the median validation loss has not decreased since the last 5 epochs


In [None]:
model.load_state_dict(torch.load(save_path))

<All keys matched successfully>

In [None]:
model = model.eval().to(device)

validate(model, test_loader, loss)

(0.6250165,
 0.6171717,
 0.6445201535508637,
 0.5661427098950524,
 0.3089414812243957,
 0.46480877785150354,
 0.7239788505609395,
 0.49898752399232243)