## Load Library and Data 

In [24]:
import numpy as np
import pandas as pd
import matplotlib as plt
from sklearn.model_selection import train_test_split
import torch
from torch.nn.utils.rnn import pad_sequence
import torch.nn as nn
import torch.nn.functional as F
from sklearn.feature_extraction.text import CountVectorizer, TfidfVectorizer

In [2]:
# load datasets
domain1 = pd.read_json('data/domain1_train.json', lines=True)
domain2 = pd.read_json('data/domain2_train.json', lines=True)
test = pd.read_json('data/test_set.json', lines=True)

In [3]:
def rated_sample(df, label_1_rate):

    sampled_label_0 = df[df['label'] == 0]
    label_0_count = len(df[df['label'] == 0])

    #calculate the number of label 1 samples
    sampled_label_1_count = int(label_1_rate * (label_0_count / (1 - label_1_rate)))
    
    # random sample with replacement from label 1
    sampled_label_1 = df[df['label'] == 1].sample(n=sampled_label_1_count, replace=True)

    # concatenate the sampled label 0 and sampled label 1
    sampled_df = pd.concat([sampled_label_0, sampled_label_1])
    
    return sampled_df

In [4]:
newdomain2 = rated_sample(domain2,0.5)
newdomain2 = newdomain2.sort_index()

Padding: fill 5000 at back of each instance to ensure they have the same length

In [5]:
def pad(df):
    # Convert dataframe to lists of texts
    text_list = df['text'].tolist()
    # Convert lists of texts into a list of tensors
    tensors_df = [torch.tensor(text) for text in text_list]
    # pad with a new token: 5000
    pad_df  = pad_sequence(tensors_df,batch_first=True, padding_value=5000)
    return pad_df

In [6]:
pad_dm1 = pad(domain1)
pad_dm2 = pad(newdomain2)
print(pad_dm1.shape, pad_dm2.shape)

torch.Size([19500, 238]) torch.Size([25500, 1075])


In [79]:
length = len(pad_dm1[0])
embed_dim = 5005

In [74]:
token_list = [i for i in range(5005)]

In [75]:
len(token_list)

5005

In [76]:
token_list.index(5000)

5000

In [80]:
# (length = 238, batch_size = 1, embedding_dim = 5001)
tensor_dm1 = torch.zeros(length,1,embed_dim)
for text in pad_dm1:
    for li, token in enumerate(text):
        tensor_dm1[li][0][token_list.index(token.item())]

In [81]:
tensor_dm1.shape

torch.Size([238, 1, 5005])

In [103]:
input_dim = length
embed_dim = embed_dim
num_heads = 5
dropout = 0.1
# Create an instance of nn.MultiheadAttention
multihead_attn = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads,dropout=dropout,batch_first=False)

In [100]:
class TextClassifier(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout):
        super(TextClassifier, self).__init__()
        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=False)
        self.fc = nn.Linear(embed_dim, 2)  # Binary classification output

    def forward(self, x):
        attn_output, _ = self.multihead_attn(x, x, x)
        output = self.fc(attn_output.mean(dim=1))  # Mean pooling over the sequence dimension
        output = F.log_softmax(output, dim=1)
        return F.softmax(output, dim=-1)

# Create an instance of the TextClassifier
classifier = TextClassifier(embed_dim, num_heads, dropout)

In [101]:
predictions = classifier(tensor_dm1)

In [99]:
predictions

tensor([[1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
      