In [502]:
import pickle
import numpy as np
from datasets import load_dataset
import tqdm
import math
import sys

In [503]:
# # GLUE
# dataset = load_dataset("glue", "sst2")
# val_dataset = dataset["validation"]

dataset = load_dataset("rotten_tomatoes")
train_dataset = dataset["train"]
test_dataset = dataset["test"]

### CNN Based Classifier

In [602]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

In [603]:
# Define the CNN architecture for 1D vectors
class CNNClassifier(nn.Module):
    def __init__(self):
        super(CNNClassifier, self).__init__()
#         self.conv1 = nn.Conv1d(in_channels = 1, out_channels = 32, kernel_size = 3)
#         self.conv2 = nn.Conv1d(in_channels = 32, out_channels = 32, kernel_size = 3)
#         self.conv3 = nn.Conv1d(in_channels = 32, out_channels = 16, kernel_size = 3)
#         self.pool = nn.MaxPool1d(kernel_size = 2)
        self.fc1 = nn.Linear(768, 4)  # Adjust input size according to your data
        self.fc2 = nn.Linear(4, 1)      # Output layer with 1 neuron for binary classification

    def forward(self, x):
#         x = self.pool(nn.functional.relu(self.conv1(x)))
#         x = self.pool(nn.functional.relu(self.conv2(x)))
#         x = self.pool(nn.functional.relu(self.conv3(x)))
#         x = x.view(-1, 16 * 94)  # Adjust the size according to the output size of the last convolutional layer
        x = nn.functional.relu(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))  # Sigmoid activation for binary classification
        return x

In [604]:
# Sample data loader for demonstration
class SampleDataset(torch.utils.data.Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

In [605]:
# Get Training Data

# Read pkl file 
with open('embeddings/rotten-MPNET/rotten-train-768.pkl', 'rb') as f:
    vectors = pickle.load(f)
    
X = np.array(vectors)

labels = []
for item in train_dataset:
    labels.append(item['label'])
    
Y = np.array(labels)

### Determining which Subset 

In [667]:
# Read Instance Scores
with open('scores/rotten-MPNET/instance-scores-rotten-L16-30-40-30.pkl', 'rb') as f:
    instance_scores = pickle.load(f)
    
# Sort indices based on the score in descending order
def sort_indices_by_values(values):
    return sorted(range(len(values)), key=lambda i: np.abs(values[i]), reverse = True)

# Example usage:
values = [3, 1, 4, 1, 5, 9, 2, 6, 5, 3]
sorted_indices = sort_indices_by_values(instance_scores)

In [761]:
# Compute subset length

percentage = 0.75
name = 'models/rotten/model-L16-75.pth' 

subset_len = int(percentage * len(X))
sorted_subset_indices = sorted_indices[:subset_len]

In [762]:
X_eff = []
Y_eff = []

for i in sorted_subset_indices:
    X_eff.append(X[i])
    Y_eff.append(Y[i])

In [763]:
N = len(X_eff)
X_train = torch.from_numpy(np.array(X_eff).reshape(N, 1, 768))
Y_train = torch.from_numpy(np.array(Y_eff).reshape(N, 1, 1)).float()
print(str(N) + ' samples selected')

6397 samples selected


In [764]:
c0 = 0
c1 = 0
for i in Y_eff:
    if i == 0:
        c0 += 1
    else:
        c1 += 1
        
print('Class 0: ' + str(c0))
print('Class 1: ' + str(c1))

Class 0: 3190
Class 1: 3207


### Start Training

In [771]:
# Initialize the model, loss function, and optimizer
model = CNNClassifier()
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr = 0.00005)

# Create data loader
trainset = SampleDataset(X_train, Y_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size = 32, shuffle=True)

# Train the model
for epoch in tqdm.tqdm(range(100)):  # Adjust number of epochs as needed
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 100 == 9:  # Print every 10 mini-batches
            print('[%d, %5d] loss: %.4f' %
                  (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0

print('Finished Training')

  2%|▏         | 2/100 [00:00<00:07, 12.31it/s]

[1,    10] loss: 0.0694
[1,   110] loss: 0.7000
[2,    10] loss: 0.0688
[2,   110] loss: 0.6945
[3,    10] loss: 0.0682
[3,   110] loss: 0.6792
[4,    10] loss: 0.0679
[4,   110] loss: 0.6705


  5%|▌         | 5/100 [00:00<00:13,  7.25it/s]

[5,    10] loss: 0.0665
[5,   110] loss: 0.6569
[6,    10] loss: 0.0654


  6%|▌         | 6/100 [00:00<00:13,  7.02it/s]

[6,   110] loss: 0.6446
[7,    10] loss: 0.0638
[7,   110] loss: 0.6343
[8,    10] loss: 0.0623
[8,   110] loss: 0.6217


  9%|▉         | 9/100 [00:01<00:10,  8.74it/s]

[9,    10] loss: 0.0612
[9,   110] loss: 0.6081
[10,    10] loss: 0.0604
[10,   110] loss: 0.5939
[11,    10] loss: 0.0580


 12%|█▏        | 12/100 [00:01<00:09,  9.47it/s]

[11,   110] loss: 0.5828
[12,    10] loss: 0.0576
[12,   110] loss: 0.5696
[13,    10] loss: 0.0559


 13%|█▎        | 13/100 [00:01<00:09,  9.41it/s]

[13,   110] loss: 0.5577
[14,    10] loss: 0.0545
[14,   110] loss: 0.5457
[15,    10] loss: 0.0544
[15,   110] loss: 0.5342


 16%|█▌        | 16/100 [00:01<00:08,  9.76it/s]

[16,    10] loss: 0.0535
[16,   110] loss: 0.5184
[17,    10] loss: 0.0522
[17,   110] loss: 0.5056
[18,    10] loss: 0.0506


 19%|█▉        | 19/100 [00:02<00:08,  9.81it/s]

[18,   110] loss: 0.5015
[19,    10] loss: 0.0493
[19,   110] loss: 0.4862
[20,    10] loss: 0.0467


 20%|██        | 20/100 [00:02<00:08,  9.65it/s]

[20,   110] loss: 0.4779
[21,    10] loss: 0.0488
[21,   110] loss: 0.4697
[22,    10] loss: 0.0468
[22,   110] loss: 0.4539


 24%|██▍       | 24/100 [00:02<00:07, 10.21it/s]

[23,    10] loss: 0.0452
[23,   110] loss: 0.4470
[24,    10] loss: 0.0444
[24,   110] loss: 0.4384
[25,    10] loss: 0.0401


 26%|██▌       | 26/100 [00:02<00:07,  9.61it/s]

[25,   110] loss: 0.4319
[26,    10] loss: 0.0428
[26,   110] loss: 0.4224
[27,    10] loss: 0.0415


 28%|██▊       | 28/100 [00:03<00:07,  9.38it/s]

[27,   110] loss: 0.4160
[28,    10] loss: 0.0400
[28,   110] loss: 0.4075
[29,    10] loss: 0.0394


 30%|███       | 30/100 [00:03<00:07,  8.83it/s]

[29,   110] loss: 0.3989
[30,    10] loss: 0.0358
[30,   110] loss: 0.3967
[31,    10] loss: 0.0387


 32%|███▏      | 32/100 [00:03<00:07,  8.53it/s]

[31,   110] loss: 0.3831
[32,    10] loss: 0.0391
[32,   110] loss: 0.3758
[33,    10] loss: 0.0360


 34%|███▍      | 34/100 [00:03<00:07,  8.90it/s]

[33,   110] loss: 0.3735
[34,    10] loss: 0.0365
[34,   110] loss: 0.3687
[35,    10] loss: 0.0365


 36%|███▌      | 36/100 [00:03<00:07,  8.73it/s]

[35,   110] loss: 0.3686
[36,    10] loss: 0.0353
[36,   110] loss: 0.3594
[37,    10] loss: 0.0337


 38%|███▊      | 38/100 [00:04<00:07,  8.59it/s]

[37,   110] loss: 0.3690
[38,    10] loss: 0.0353
[38,   110] loss: 0.3490
[39,    10] loss: 0.0336


 40%|████      | 40/100 [00:04<00:07,  8.43it/s]

[39,   110] loss: 0.3481
[40,    10] loss: 0.0344
[40,   110] loss: 0.3399
[41,    10] loss: 0.0347


 42%|████▏     | 42/100 [00:04<00:06,  8.59it/s]

[41,   110] loss: 0.3346
[42,    10] loss: 0.0377
[42,   110] loss: 0.3325
[43,    10] loss: 0.0322


 44%|████▍     | 44/100 [00:04<00:06,  8.26it/s]

[43,   110] loss: 0.3281
[44,    10] loss: 0.0340
[44,   110] loss: 0.3311
[45,    10] loss: 0.0307


 46%|████▌     | 46/100 [00:05<00:06,  8.31it/s]

[45,   110] loss: 0.3265
[46,    10] loss: 0.0327
[46,   110] loss: 0.3246
[47,    10] loss: 0.0331


 48%|████▊     | 48/100 [00:05<00:06,  8.55it/s]

[47,   110] loss: 0.3181
[48,    10] loss: 0.0321
[48,   110] loss: 0.3190
[49,    10] loss: 0.0327


 49%|████▉     | 49/100 [00:05<00:05,  8.81it/s]

[49,   110] loss: 0.3127
[50,    10] loss: 0.0318
[50,   110] loss: 0.3221
[51,    10] loss: 0.0267
[51,   110] loss: 0.3138


 52%|█████▏    | 52/100 [00:05<00:05,  9.25it/s]

[52,    10] loss: 0.0327
[52,   110] loss: 0.3175
[53,    10] loss: 0.0343
[53,   110] loss: 0.3137
[54,    10] loss: 0.0330

 54%|█████▍    | 54/100 [00:05<00:04,  9.57it/s]


[54,   110] loss: 0.3087
[55,    10] loss: 0.0280
[55,   110] loss: 0.3051


 57%|█████▋    | 57/100 [00:06<00:04,  9.72it/s]

[56,    10] loss: 0.0277
[56,   110] loss: 0.3044
[57,    10] loss: 0.0291
[57,   110] loss: 0.3138
[58,    10] loss: 0.0299

 59%|█████▉    | 59/100 [00:06<00:04, 10.02it/s]


[58,   110] loss: 0.3117
[59,    10] loss: 0.0328
[59,   110] loss: 0.3039
[60,    10] loss: 0.0277


 61%|██████    | 61/100 [00:06<00:03, 10.22it/s]

[60,   110] loss: 0.3058
[61,    10] loss: 0.0283
[61,   110] loss: 0.2890
[62,    10] loss: 0.0302
[62,   110] loss: 0.2927


 63%|██████▎   | 63/100 [00:06<00:03, 10.28it/s]

[63,    10] loss: 0.0354
[63,   110] loss: 0.2981
[64,    10] loss: 0.0315
[64,   110] loss: 0.3009
[65,    10] loss: 0.0284


 65%|██████▌   | 65/100 [00:07<00:03, 10.39it/s]

[65,   110] loss: 0.2940
[66,    10] loss: 0.0300
[66,   110] loss: 0.2947
[67,    10] loss: 0.0270
[67,   110] loss: 0.3139


 69%|██████▉   | 69/100 [00:07<00:03, 10.32it/s]

[68,    10] loss: 0.0291
[68,   110] loss: 0.2850
[69,    10] loss: 0.0327
[69,   110] loss: 0.2878
[70,    10] loss: 0.0329
[70,   110] loss: 0.2937
[71,    10] loss: 0.0266


 71%|███████   | 71/100 [00:07<00:03,  7.83it/s]

[71,   110] loss: 0.2907
[72,    10] loss: 0.0268
[72,   110] loss: 0.2793


 73%|███████▎  | 73/100 [00:08<00:04,  6.56it/s]

[73,    10] loss: 0.0332
[73,   110] loss: 0.2870
[74,    10] loss: 0.0283


 74%|███████▍  | 74/100 [00:08<00:04,  6.14it/s]

[74,   110] loss: 0.2741
[75,    10] loss: 0.0315
[75,   110] loss: 0.2856


 76%|███████▌  | 76/100 [00:08<00:04,  5.55it/s]

[76,    10] loss: 0.0290
[76,   110] loss: 0.2804
[77,    10] loss: 0.0304


 77%|███████▋  | 77/100 [00:09<00:04,  5.33it/s]

[77,   110] loss: 0.2891


 78%|███████▊  | 78/100 [00:09<00:04,  5.22it/s]

[78,    10] loss: 0.0292
[78,   110] loss: 0.2885


 79%|███████▉  | 79/100 [00:09<00:04,  5.13it/s]

[79,    10] loss: 0.0277
[79,   110] loss: 0.2831


 80%|████████  | 80/100 [00:09<00:03,  5.09it/s]

[80,    10] loss: 0.0280
[80,   110] loss: 0.2940
[81,    10] loss: 0.0317

 81%|████████  | 81/100 [00:09<00:03,  5.05it/s]


[81,   110] loss: 0.2911


 82%|████████▏ | 82/100 [00:10<00:03,  5.01it/s]

[82,    10] loss: 0.0297
[82,   110] loss: 0.2910


 83%|████████▎ | 83/100 [00:10<00:03,  4.99it/s]

[83,    10] loss: 0.0280
[83,   110] loss: 0.2850


 84%|████████▍ | 84/100 [00:10<00:03,  4.98it/s]

[84,    10] loss: 0.0266
[84,   110] loss: 0.2906


 85%|████████▌ | 85/100 [00:10<00:03,  4.97it/s]

[85,    10] loss: 0.0285
[85,   110] loss: 0.2916


 86%|████████▌ | 86/100 [00:10<00:02,  4.95it/s]

[86,    10] loss: 0.0259
[86,   110] loss: 0.2932


 87%|████████▋ | 87/100 [00:11<00:02,  4.94it/s]

[87,    10] loss: 0.0303
[87,   110] loss: 0.2778


 88%|████████▊ | 88/100 [00:11<00:02,  4.94it/s]

[88,    10] loss: 0.0375
[88,   110] loss: 0.2766


 89%|████████▉ | 89/100 [00:11<00:02,  4.94it/s]

[89,    10] loss: 0.0277
[89,   110] loss: 0.2907


 90%|█████████ | 90/100 [00:11<00:02,  4.92it/s]

[90,    10] loss: 0.0268
[90,   110] loss: 0.2868


 91%|█████████ | 91/100 [00:11<00:01,  4.88it/s]

[91,    10] loss: 0.0282
[91,   110] loss: 0.2800


 92%|█████████▏| 92/100 [00:12<00:01,  4.87it/s]

[92,    10] loss: 0.0258
[92,   110] loss: 0.2885


 93%|█████████▎| 93/100 [00:12<00:01,  4.87it/s]

[93,    10] loss: 0.0281
[93,   110] loss: 0.2857


 94%|█████████▍| 94/100 [00:12<00:01,  4.87it/s]

[94,    10] loss: 0.0285
[94,   110] loss: 0.2874


 95%|█████████▌| 95/100 [00:12<00:01,  4.87it/s]

[95,    10] loss: 0.0265
[95,   110] loss: 0.2838


 96%|█████████▌| 96/100 [00:12<00:00,  4.86it/s]

[96,    10] loss: 0.0279
[96,   110] loss: 0.2763


 97%|█████████▋| 97/100 [00:13<00:00,  4.87it/s]

[97,    10] loss: 0.0320
[97,   110] loss: 0.2711


 98%|█████████▊| 98/100 [00:13<00:00,  4.86it/s]

[98,    10] loss: 0.0246
[98,   110] loss: 0.2776


 99%|█████████▉| 99/100 [00:13<00:00,  4.85it/s]

[99,    10] loss: 0.0348
[99,   110] loss: 0.2732


100%|██████████| 100/100 [00:13<00:00,  7.27it/s]

[100,    10] loss: 0.0246
[100,   110] loss: 0.2898
Finished Training





In [772]:
# Save the trained model
torch.save(model.state_dict(), name)

# Perform inference with the saved model
loaded_model = CNNClassifier()
loaded_model.load_state_dict(torch.load(name))
loaded_model.eval()

CNNClassifier(
  (fc1): Linear(in_features=768, out_features=4, bias=True)
  (fc2): Linear(in_features=4, out_features=1, bias=True)
)

In [773]:
# Lets do testing

# Example inference
# sample_input = torch.randn(1, 1, 768)  # Example input
# output = loaded_model(sample_input)
# print("Model output:", output.item())

# Read pkl file 
with open('embeddings/rotten-MPNET/rotten-test-768.pkl', 'rb') as f:
    test_vectors = pickle.load(f)

In [774]:
gt = []
preds = []

for i in tqdm.tqdm(range(len(test_vectors))):
    input_vector = torch.tensor(test_vectors[i].reshape(1, 1, 768))
    actual = test_dataset[i]['label']
    output = loaded_model(input_vector) 
    if output < 0.5:
        pred = 0
    else:
        pred = 1
    gt.append(actual)
    preds.append(pred)

100%|██████████| 1066/1066 [00:00<00:00, 13721.51it/s]


In [775]:
# Print the number of misclassifications
misc = 0
for i in range(len(gt)):
    if gt[i] != preds[i]:
        misc += 1

print(misc)

159


In [776]:
from sklearn.metrics import classification_report

# Generate classification report
report = classification_report(gt, preds)

print("Classification Report:")
print(report)

Classification Report:
              precision    recall  f1-score   support

           0       0.85      0.86      0.85       533
           1       0.86      0.84      0.85       533

    accuracy                           0.85      1066
   macro avg       0.85      0.85      0.85      1066
weighted avg       0.85      0.85      0.85      1066

