In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


from nltk.tokenize import sent_tokenize
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from transformers import BertTokenizer,BertModel

from torch.utils.data import Dataset

from datasets import load_from_disk,load_dataset

import pickle

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

#Additional Info when using cuda
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')


Using device: cuda
NVIDIA GeForce RTX 3060 Ti
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


In [3]:
newsgroup_configs = ['bydate_alt.atheism',
                     'bydate_comp.graphics',
                     'bydate_comp.os.ms-windows.misc',
                     'bydate_comp.sys.ibm.pc.hardware',
                     'bydate_comp.sys.mac.hardware',
                     'bydate_comp.windows.x',
                     'bydate_misc.forsale',
                     'bydate_rec.autos',
                     'bydate_rec.motorcycles',
                     'bydate_rec.sport.baseball',
                     'bydate_rec.sport.hockey',
                     'bydate_sci.crypt',
                     'bydate_sci.electronics',
                     'bydate_sci.med',
                     'bydate_sci.space',
                     'bydate_soc.religion.christian',
                     'bydate_talk.politics.guns',
                     'bydate_talk.politics.mideast',
                     'bydate_talk.politics.misc',
                     'bydate_talk.religion.misc']

In [4]:
threshold = 1.0


In [76]:
dataset_list = []
split = 'train'
#split = 'test'
for config in newsgroup_configs:
    subset_path = r'\\wsl$\Ubuntu-20.04\home\jolteon\NLUProject\data\20news\\'+ split+ '\\'+ config
    dataset_list.append((config,load_from_disk(subset_path)))

In [72]:
label_to_cutoff_indices_file = r'\\wsl$\Ubuntu-20.04\home\jolteon\NLUProject\data\20news\processed\train\label_to_cutoff_indices_'+str(threshold)+'.pkl'
with open(label_to_cutoff_indices_file, 'rb') as handle:
    label_to_cutoff_indices_dict = pickle.load(handle)

In [95]:
for key in  label_to_cutoff_indices_dict['bydate_alt.atheism'].keys():
    print(len(label_to_cutoff_indices_dict['bydate_alt.atheism'][key]))


95
316
49
14
5
54
7
12
28
3
6
5
12
26
30
13
7
11
19
13
3
7
100
6
6
6
5
5
15
4
8
23
6
35
8
11
63
15
5
5
8
10
24
10
16
35
12
2
5
22
25
4
26
31
13
7
9
8
9
30
16
7
6
5
22
13
10
7
8
19
14
10
6
20
6
18
9
17
12
17
10
8
13
10
56
21
18
123
24
16
6
5
5
26
16
6
25
6
13
12
16
7
22
5
12
8
9
28
9
14
12
25
12
4
141
18
11
9
14
24
7
9
12
11
14
34
4
14
30
12
7
28
10
17
56
15
20
7
3
7
24
5
33
10
20
8
11
19
28
10
48
6
18
13
6
33
7
11
6
8
12
78
8
28
15
14
27
16
32
18
13
20
63
15
17
89
12
9
42
58
8
6
9
12
4
4
9
28
3
46
82
18
237
245
40
8
6
3
46
13
9
16
4
100
27
7
6
18
13
15
19
3
10
6
13
23
11
3
47
12
13
10
6
10
11
10
15
6
21
6
22
24
12
11
17
31
15
19
7
45
21
9
6
11
8
9
7
7
15
55
28
7
31
30
4
8
5
12
7
16
11
19
12
12
33
3
21
45
16
16
20
8
5
3
12
12
19
9
12
13
8
11
16
19
17
26
49
13
13
16
26
31
7
25
9
21
26
12
7
24
15
26
11
17
17
9
25
8
17
37
14
14
11
12
9
33
11
7
6
10
19
14
49
8
8
6
7
8
5
12
24
21
5
6
9
7
11
13
14
10
9
21
8
26
45
14
28
9
20
31
5
38
11
4
32
41
22
12
11
123
58
9
18
8
15
5
6
9
10
8
4
192
3
2
69


In [84]:
temp = dataset_list[0][1]['text'][0]

In [85]:
sentence_list = sent_tokenize(temp)

In [96]:
sentence_list

['From: mathew <mathew@mantis.co.uk>\nSubject: Alt.Atheism FAQ: Atheist Resources\nSummary: Books, addresses, music -- anything related to atheism\nKeywords: FAQ, atheism, books, music, fiction, addresses, contacts\nExpires: Thu, 29 Apr 1993 11:57:19 GMT\nDistribution: world\nOrganization: Mantis Consultants, Cambridge.',
 'UK.',
 'Supersedes: <19930301143317@mantis.co.uk>\nLines: 290\n\nArchive-name: atheism/resources\nAlt-atheism-archive-name: resources\nLast-modified: 11 December 1992\nVersion: 1.0\n\n                              Atheist Resources\n\n                      Addresses of Atheist Organizations\n\n                                     USA\n\nFREEDOM FROM RELIGION FOUNDATION\n\nDarwin fish bumper stickers and assorted other atheist paraphernalia are\navailable from the Freedom From Religion Foundation in the US.',
 'Write to:  FFRF, P.O.',
 'Box 750, Madison, WI 53701.',
 'Telephone: (608) 256-8900\n\nEVOLUTION DESIGNS\n\nEvolution Designs sell the "Darwin fish".',
 'It\'

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…




In [14]:
with open(processed_dir+ split+'\\' + file_name + str(threshold) +'.pkl', 'rb') as handle:
    bert_encoded_segments_list = pickle.load(handle)
 

In [15]:
for ii, batch in enumerate(bert_encoded_segments_list):
    label = batch[0]
    temp = batch[1]
    print(label)
    print(temp.shape)
    if ii == 10:
        break

tensor([0])
torch.Size([96, 768])
tensor([0])
torch.Size([317, 768])
tensor([0])
torch.Size([50, 768])
tensor([0])
torch.Size([15, 768])
tensor([0])
torch.Size([6, 768])
tensor([0])
torch.Size([55, 768])
tensor([0])
torch.Size([8, 768])
tensor([0])
torch.Size([13, 768])
tensor([0])
torch.Size([29, 768])
tensor([0])
torch.Size([4, 768])
tensor([0])
torch.Size([7, 768])


In [16]:
class EncodedSegmentsDataset(Dataset):
    def __init__(self,data_list):
        self.data_list = data_list
        
    def __len__(self):
        return len(self.data_list)
 
    def __getitem__(self,idx):
        return(self.data_list[idx])

In [17]:
encoded_dataset = EncodedSegmentsDataset(bert_encoded_segments_list)
val_prop =.1
bsize = 1

dataset_size = len(encoded_dataset)
val_size = int(val_prop * dataset_size)
train_size = dataset_size - val_size

train_dataset, val_dataset =  torch.utils.data.random_split(encoded_dataset,[train_size,val_size])
encoded_train_loader = DataLoader(train_dataset,batch_size=bsize,shuffle=True, pin_memory=True)
encoded_val_loader = DataLoader(val_dataset,batch_size=bsize,shuffle=True, pin_memory=True)



In [70]:
for ii, batch in enumerate(encoded_train_loader):
    label = batch[0]
    tensor = batch[1]
    print(label)
    print(tensor.shape)
    print(tensor[0,0,10:20])
    if ii ==2:
        break

tensor([[11]])
torch.Size([1, 7, 768])
tensor([-0.5196,  0.9922,  0.1127,  0.9911,  0.0941, -0.6303, -0.1405, -0.3908,
         0.3025,  0.5824])
tensor([[17]])
torch.Size([1, 23, 768])
tensor([-0.4757,  0.9918,  0.0328,  0.9892,  0.0400, -0.6188, -0.1452, -0.3251,
         0.2685,  0.5536])
tensor([[5]])
torch.Size([1, 6, 768])
tensor([-0.6220,  0.9918,  0.5029,  0.9865,  0.4022, -0.6757, -0.1405, -0.4920,
         0.3431,  0.6357])


In [67]:
class LSTMoverBERT(nn.Module):
    def __init__(self):
        super().__init__()
        self.LSTM = nn.LSTM(input_size=768,hidden_size = 128,num_layers=1,batch_first=True)
        self.activation = nn.ReLU()
        self.linear1 = nn.Linear(in_features=128,out_features=64)
        self.linear2 = nn.Linear(in_features=64,out_features=20)

        self.softmax = nn.Softmax(dim=1)
        

    def forward(self, x,verbose=False):
        
        #print('input x:', x.shape)
        LSTM_out,LSTM_states = self.LSTM(x)
        #print('LSTM out:', LSTM_out.shape)
        #print('LSTM states[0]:', LSTM_states[0].shape)
        #print('LSTM states[1]:', LSTM_states[1].shape)
        last_hidden_state = LSTM_states[1][0]
        #last_embedding = LSTM_out[:,-1,:]
        out = self.linear1(last_hidden_state)
        #print('linear out', out.shape) if verbose
        out = self.activation(out)
        #print('activation out', out.shape) if verbose
        out = self.linear2(out)

        out= self.softmax(out)
        #raise Exception()
        return out



In [68]:
LoBERT_model = LSTMoverBERT()
LoBERT_model.to(device)
LoBERT_model.train()
criterion = nn.CrossEntropyLoss(reduction='sum')
optimizer = optim.Adam(LoBERT_model.parameters(),lr=0.000005,)

In [69]:
train_loss_list = []
val_loss_list = []
train_accuracy_list = []
val_accuracy_list = []

for epoch in tqdm(range(30)):  # loop over the dataset multiple times
    train_loss = 0
    train_correct = 0
    val_loss = 0
    val_correct = 0
    
    #START TRAIN
    LoBERT_model.train()
    for idx, batch in enumerate(encoded_train_loader):
        # Define and move to GPU
        label = batch[0][0]
        model_input = batch[1]
        label = label.to(device)
        model_input = model_input.to(device)
        # Forward Pass
        out = LoBERT_model(model_input)
        loss = criterion(out,label)
        #Record Metrics pt 1/2
        train_loss += loss.item()
        pred = torch.argmax(out)
        train_correct +=(pred == label).sum()
        
        #Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    #Record Metrics pt 2/2
    train_loss = train_loss/ len(encoded_train_loader)
    train_accuracy = train_correct/len(encoded_train_loader)
    #Print and save
    print('Epoch:', epoch, 'train_loss:',train_loss, 'accuracy: ',train_accuracy)
    train_loss_list.append(train_loss)
    train_accuracy_list.append(train_accuracy)
    
    # START VAL
    LoBERT_model.eval()
    
    with torch.no_grad():
        for idx, batch in enumerate(encoded_val_loader):
            #Define and move to GPU
            label = batch[0][0]
            model_input = batch[1]
            label = label.to(device)
            model_input = model_input.to(device)
            #Forward Pass
            out = LoBERT_model(model_input)
            loss = criterion(out,label)
            #Record metrics pt 1/2
            val_loss += loss.item()

            pred = torch.argmax(out)
            val_correct +=(pred == label).sum()
            
    #Record metrics pt 2/2
    val_loss = val_loss/ len(encoded_val_loader)
    val_accuracy = val_correct/len(encoded_val_loader)
    #Print and save
    print('Epoch:', epoch, 'val_loss:',val_loss, 'accuracy: ',val_accuracy)
    val_loss_list.append(val_loss)
    val_accuracy_list.append(val_accuracy)


HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

Epoch: 0 train_loss: 2.9950206851743824 accuracy:  tensor(0.0538, device='cuda:0')
Epoch: 0 val_loss: 2.994527119645811 accuracy:  tensor(0.0424, device='cuda:0')
Epoch: 1 train_loss: 2.9924446501362407 accuracy:  tensor(0.0574, device='cuda:0')
Epoch: 1 val_loss: 2.991097028228058 accuracy:  tensor(0.0637, device='cuda:0')
Epoch: 2 train_loss: 2.9879090178636813 accuracy:  tensor(0.0703, device='cuda:0')
Epoch: 2 val_loss: 2.9880221771203135 accuracy:  tensor(0.0769, device='cuda:0')
Epoch: 3 train_loss: 2.9853159153752955 accuracy:  tensor(0.0696, device='cuda:0')
Epoch: 3 val_loss: 2.9867270804421233 accuracy:  tensor(0.0725, device='cuda:0')
Epoch: 4 train_loss: 2.9844657698139323 accuracy:  tensor(0.0695, device='cuda:0')
Epoch: 4 val_loss: 2.9861227361424425 accuracy:  tensor(0.0734, device='cuda:0')
Epoch: 5 train_loss: 2.9834546213310174 accuracy:  tensor(0.0686, device='cuda:0')
Epoch: 5 val_loss: 2.9839841219095082 accuracy:  tensor(0.0734, device='cuda:0')
Epoch: 6 train_los