In [236]:
import pickle
import numpy as np
from datasets import load_dataset
import tqdm
import math
import sys

In [237]:
dataset = load_dataset("glue", "sst2")
train_dataset = dataset["train"]
val_dataset = dataset["validation"]
test_dataset = dataset["test"]

### CNN Based Classifier

In [238]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

In [239]:
# Define the CNN architecture for 1D vectors
class CNNClassifier(nn.Module):
    def __init__(self):
        super(CNNClassifier, self).__init__()
        self.conv1 = nn.Conv1d(in_channels = 1, out_channels = 16, kernel_size=3)
        self.conv2 = nn.Conv1d(in_channels = 16, out_channels = 16, kernel_size=3)
        self.pool = nn.MaxPool1d(kernel_size=2)
        self.fc1 = nn.Linear(16 * 190, 128)  # Adjust input size according to your data
        self.fc2 = nn.Linear(128, 1)      # Output layer with 1 neuron for binary classification

    def forward(self, x):
        x = self.pool(nn.functional.relu(self.conv1(x)))
        x = self.pool(nn.functional.relu(self.conv2(x)))
        x = x.view(-1, 16 * 190)  # Adjust the size according to the output size of the last convolutional layer
        x = nn.functional.relu(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))  # Sigmoid activation for binary classification
        return x.unsqueeze(-1)

In [240]:
# Sample data loader for demonstration
class SampleDataset(torch.utils.data.Dataset):
    def __init__(self, data, labels):
        self.data = data
        self.labels = labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

In [241]:
# Get Training Data

# Read pkl file 
with open('sst-train-768.pkl', 'rb') as f:
    vectors = pickle.load(f)
    
X = np.array(vectors)

labels = []
for item in train_dataset:
    labels.append(item['label'])
    
Y = np.array(labels)

### Determining which Subset 

In [415]:
# Read Instance Scores
with open('scores/instance-scores-L0.pkl', 'rb') as f:
    instance_scores = pickle.load(f)

In [434]:
X_eff = []
Y_eff = []

M = len(instance_scores)
for i in range(M):
    if instance_scores[i] > 0:
        X_eff.append(X[i])
        Y_eff.append(Y[i])

In [435]:
N = len(X_eff)
X_train = torch.from_numpy(np.array(X_eff).reshape(N, 1, 768))
Y_train = torch.from_numpy(np.array(Y_eff).reshape(N, 1, 1)).float()
print(str(N) + ' samples selected')

58896 samples selected


In [436]:
c0 = 0
c1 = 0
for i in Y_eff:
    if i == 0:
        c0 += 1
    else:
        c1 += 1
        
print('Class 0: ' + str(c0))
print('Class 1: ' + str(c1))

Class 0: 26390
Class 1: 32506


### Start Training

In [437]:
# Initialize the model, loss function, and optimizer
model = CNNClassifier()
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr = 0.0001)

# Create data loader
trainset = SampleDataset(X_train, Y_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size = 16, shuffle=True)

# Train the model
for epoch in range(3):  # Adjust number of epochs as needed
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 100 == 9:  # Print every 10 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 10))
            running_loss = 0.0

print('Finished Training')

[1,    10] loss: 0.691
[1,   110] loss: 6.873
[1,   210] loss: 6.840
[1,   310] loss: 6.857
[1,   410] loss: 6.858
[1,   510] loss: 6.856
[1,   610] loss: 6.823
[1,   710] loss: 6.853
[1,   810] loss: 6.836
[1,   910] loss: 6.783
[1,  1010] loss: 6.769
[1,  1110] loss: 6.709
[1,  1210] loss: 6.633
[1,  1310] loss: 6.459
[1,  1410] loss: 6.296
[1,  1510] loss: 5.966
[1,  1610] loss: 5.480
[1,  1710] loss: 4.874
[1,  1810] loss: 4.193
[1,  1910] loss: 3.661
[1,  2010] loss: 3.227
[1,  2110] loss: 2.842
[1,  2210] loss: 2.409
[1,  2310] loss: 2.263
[1,  2410] loss: 1.962
[1,  2510] loss: 1.965
[1,  2610] loss: 1.816
[1,  2710] loss: 1.682
[1,  2810] loss: 1.727
[1,  2910] loss: 1.582
[1,  3010] loss: 1.538
[1,  3110] loss: 1.397
[1,  3210] loss: 1.415
[1,  3310] loss: 1.353
[1,  3410] loss: 1.285
[1,  3510] loss: 1.285
[1,  3610] loss: 1.274
[2,    10] loss: 0.115
[2,   110] loss: 1.225
[2,   210] loss: 1.222
[2,   310] loss: 1.196
[2,   410] loss: 1.142
[2,   510] loss: 0.976
[2,   610] 

In [438]:
# Save the trained model
torch.save(model.state_dict(), 'models/model-L0-0.pth')

In [439]:
# Perform inference with the saved model
loaded_model = CNNClassifier()
loaded_model.load_state_dict(torch.load('models/model-L0-0.pth'))
loaded_model.eval()

CNNClassifier(
  (conv1): Conv1d(1, 16, kernel_size=(3,), stride=(1,))
  (conv2): Conv1d(16, 16, kernel_size=(3,), stride=(1,))
  (pool): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=3040, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=1, bias=True)
)

In [440]:
# Example inference
sample_input = torch.randn(1, 1, 768)  # Example input
output = loaded_model(sample_input)
print("Model output:", output.item())

Model output: 1.0


In [441]:
# Lets do testing

# Read pkl file 
with open('sst-val-768.pkl', 'rb') as f:
    test_vectors = pickle.load(f)


In [442]:
gt = []
preds = []

for i in range(len(test_vectors)):
    input_vector = torch.tensor(test_vectors[i].reshape(1, 1, 768))
    actual = val_dataset[i]['label']
    output = loaded_model(input_vector) 
    if output < 0.5:
        pred = 0
    else:
        pred = 1
    gt.append(actual)
    preds.append(pred)

In [443]:
# Print the number of misclassifications
misc = 0
for i in range(len(gt)):
    if gt[i] != preds[i]:
        misc += 1

print(misc)

113


In [444]:
from sklearn.metrics import classification_report

# Generate classification report
report = classification_report(gt, preds)

print("Classification Report:")
print(report)

Classification Report:
              precision    recall  f1-score   support

           0       0.89      0.84      0.86       428
           1       0.85      0.90      0.88       444

    accuracy                           0.87       872
   macro avg       0.87      0.87      0.87       872
weighted avg       0.87      0.87      0.87       872

