In [1]:
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch
from tqdm import tqdm

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
if torch.cuda.is_available():
  device = torch.device("cuda:0")
  print("running on the GPU")
else:
  device = torch.device("cpu")
  print("running on the CPU")

running on the GPU


In [4]:
training_data = np.load("/content/drive/My Drive/training_data.npy", allow_pickle=True)

In [6]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(300, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 256)
        self.fc4 = nn.Linear(256, 2)
        
    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = self.fc4(x)
        return F.softmax(x, dim=1)

In [7]:
tr_size = round(len(training_data)*0.8)
tr_set = training_data[:tr_size]
te_set = training_data[tr_size:]

In [8]:
net = Net().to(device)
print(net)

Net(
  (fc1): Linear(in_features=300, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=256, bias=True)
  (fc3): Linear(in_features=256, out_features=256, bias=True)
  (fc4): Linear(in_features=256, out_features=2, bias=True)
)


In [26]:
BATCH_SIZE = 100
def train(net):
  EPOCHS = 30*8
  loss_function = nn.MSELoss()
  optimizer = optim.Adam(net.parameters(), lr=0.001)
  loss = 1000000000
  test_accuracy_at_each_epoch = []
  for epoch in range(EPOCHS):
      for i in tqdm(range(0, len(tr_set), BATCH_SIZE)):
          batch_X = [x for x, y_ in tr_set[i:i+BATCH_SIZE]]
          batch_y = [y_ for x, y_ in tr_set[i:i+BATCH_SIZE]]

          tensor_batch_X = torch.tensor(batch_X).float()
          tensor_batch_y = torch.tensor(batch_y).float()

          tensor_batch_X, tensor_batch_y = tensor_batch_X.to(device), tensor_batch_y.to(device)

          net.zero_grad()

          output = net(tensor_batch_X.view(-1, 300))
          loss = loss_function(output, tensor_batch_y)
          loss.backward()
          optimizer.step()
      
      accuracy = test(net)
      test_accuracy_at_each_epoch.append(accuracy)
      print(f"\nLoss: {loss}")
      print(f"\nAccuracy: {accuracy}")
  return test_accuracy_at_each_epoch

In [None]:
test_accuracy_epoch = train(net)

100%|██████████| 18164/18164 [02:03<00:00, 146.89it/s]
100%|██████████| 4541/4541 [00:48<00:00, 93.98it/s]



Loss: 0.15786199271678925


100%|██████████| 18164/18164 [02:00<00:00, 151.17it/s]
100%|██████████| 4541/4541 [00:48<00:00, 94.56it/s]



Loss: 0.16346590220928192


100%|██████████| 18164/18164 [01:59<00:00, 151.45it/s]
100%|██████████| 4541/4541 [00:48<00:00, 93.89it/s]



Loss: 0.16107496619224548


100%|██████████| 18164/18164 [02:11<00:00, 138.63it/s]
100%|██████████| 4541/4541 [00:51<00:00, 88.35it/s]



Loss: 0.1566895693540573


100%|██████████| 18164/18164 [02:01<00:00, 149.18it/s]
100%|██████████| 4541/4541 [00:48<00:00, 94.44it/s]



Loss: 0.1550334393978119


100%|██████████| 18164/18164 [02:00<00:00, 151.17it/s]
100%|██████████| 4541/4541 [00:47<00:00, 94.62it/s]



Loss: 0.15701739490032196


100%|██████████| 18164/18164 [02:00<00:00, 150.34it/s]
100%|██████████| 4541/4541 [00:48<00:00, 93.95it/s]



Loss: 0.1604745090007782


100%|██████████| 18164/18164 [02:02<00:00, 148.83it/s]
100%|██████████| 4541/4541 [00:48<00:00, 92.78it/s]



Loss: 0.15262307226657867


100%|██████████| 18164/18164 [02:01<00:00, 148.94it/s]
100%|██████████| 4541/4541 [00:48<00:00, 93.46it/s]



Loss: 0.15765908360481262


100%|██████████| 18164/18164 [02:00<00:00, 150.78it/s]
100%|██████████| 4541/4541 [00:48<00:00, 94.15it/s]



Loss: 0.15651848912239075


100%|██████████| 18164/18164 [02:01<00:00, 149.38it/s]
100%|██████████| 4541/4541 [00:48<00:00, 93.81it/s]



Loss: 0.1553516983985901


100%|██████████| 18164/18164 [02:02<00:00, 148.43it/s]
100%|██████████| 4541/4541 [00:49<00:00, 92.60it/s]



Loss: 0.15505467355251312


100%|██████████| 18164/18164 [02:01<00:00, 149.63it/s]
100%|██████████| 4541/4541 [00:48<00:00, 92.79it/s]



Loss: 0.15511454641819


100%|██████████| 18164/18164 [02:01<00:00, 149.06it/s]
100%|██████████| 4541/4541 [00:49<00:00, 92.56it/s]



Loss: 0.1552474945783615


100%|██████████| 18164/18164 [02:02<00:00, 148.12it/s]
100%|██████████| 4541/4541 [00:48<00:00, 94.15it/s]



Loss: 0.14948907494544983


100%|██████████| 18164/18164 [02:01<00:00, 149.76it/s]
100%|██████████| 4541/4541 [00:48<00:00, 93.45it/s]



Loss: 0.14549556374549866


100%|██████████| 18164/18164 [02:01<00:00, 149.92it/s]
100%|██████████| 4541/4541 [00:48<00:00, 94.11it/s]



Loss: 0.1460811197757721


100%|██████████| 18164/18164 [02:00<00:00, 150.17it/s]
100%|██████████| 4541/4541 [00:48<00:00, 93.93it/s]



Loss: 0.1449686586856842


100%|██████████| 18164/18164 [02:01<00:00, 149.36it/s]
100%|██████████| 4541/4541 [00:48<00:00, 93.75it/s]



Loss: 0.14404012262821198


100%|██████████| 18164/18164 [02:01<00:00, 149.64it/s]
100%|██████████| 4541/4541 [00:48<00:00, 93.81it/s]



Loss: 0.14466135203838348


100%|██████████| 18164/18164 [02:01<00:00, 149.25it/s]
100%|██████████| 4541/4541 [00:48<00:00, 93.13it/s]



Loss: 0.14692100882530212


100%|██████████| 18164/18164 [02:02<00:00, 147.71it/s]
100%|██████████| 4541/4541 [00:48<00:00, 92.84it/s]



Loss: 0.1433178037405014


100%|██████████| 18164/18164 [02:02<00:00, 148.85it/s]
100%|██████████| 4541/4541 [00:48<00:00, 94.57it/s]



Loss: 0.1426757574081421


100%|██████████| 18164/18164 [01:59<00:00, 151.84it/s]
100%|██████████| 4541/4541 [00:47<00:00, 94.97it/s]



Loss: 0.1427626758813858


100%|██████████| 18164/18164 [02:00<00:00, 151.13it/s]
100%|██████████| 4541/4541 [00:47<00:00, 95.00it/s]



Loss: 0.14095087349414825


100%|██████████| 18164/18164 [02:00<00:00, 150.31it/s]
100%|██████████| 4541/4541 [00:48<00:00, 93.66it/s]



Loss: 0.13985389471054077


100%|██████████| 18164/18164 [02:01<00:00, 149.09it/s]
100%|██████████| 4541/4541 [00:48<00:00, 93.14it/s]



Loss: 0.138587087392807


100%|██████████| 18164/18164 [01:59<00:00, 151.42it/s]
100%|██████████| 4541/4541 [00:47<00:00, 94.95it/s]



Loss: 0.1362352967262268


100%|██████████| 18164/18164 [01:59<00:00, 152.14it/s]
100%|██████████| 4541/4541 [00:47<00:00, 96.11it/s]



Loss: 0.13705763220787048


100%|██████████| 18164/18164 [02:00<00:00, 150.63it/s]
100%|██████████| 4541/4541 [00:47<00:00, 95.26it/s]



Loss: 0.13456198573112488


100%|██████████| 18164/18164 [02:00<00:00, 150.35it/s]
100%|██████████| 4541/4541 [00:48<00:00, 93.63it/s]



Loss: 0.13842658698558807


100%|██████████| 18164/18164 [02:00<00:00, 150.19it/s]
100%|██████████| 4541/4541 [00:48<00:00, 94.47it/s]



Loss: 0.13511739671230316


100%|██████████| 18164/18164 [02:00<00:00, 150.27it/s]
100%|██████████| 4541/4541 [00:48<00:00, 94.32it/s]



Loss: 0.13420532643795013


100%|██████████| 18164/18164 [02:00<00:00, 150.57it/s]
100%|██████████| 4541/4541 [00:48<00:00, 94.40it/s]



Loss: 0.13494913280010223


100%|██████████| 18164/18164 [01:59<00:00, 151.58it/s]
100%|██████████| 4541/4541 [00:48<00:00, 94.60it/s]



Loss: 0.1341399848461151


100%|██████████| 18164/18164 [01:58<00:00, 153.06it/s]
100%|██████████| 4541/4541 [00:47<00:00, 95.43it/s]



Loss: 0.1290842741727829


100%|██████████| 18164/18164 [01:59<00:00, 151.65it/s]
100%|██████████| 4541/4541 [00:48<00:00, 93.48it/s]



Loss: 0.1369473934173584


100%|██████████| 18164/18164 [02:01<00:00, 149.65it/s]
100%|██████████| 4541/4541 [00:48<00:00, 94.52it/s]



Loss: 0.1357455551624298


100%|██████████| 18164/18164 [01:58<00:00, 152.66it/s]
100%|██████████| 4541/4541 [00:47<00:00, 95.12it/s]



Loss: 0.13116401433944702


100%|██████████| 18164/18164 [02:01<00:00, 149.07it/s]
100%|██████████| 4541/4541 [00:48<00:00, 94.34it/s]



Loss: 0.12969447672367096


100%|██████████| 18164/18164 [02:01<00:00, 149.55it/s]
100%|██████████| 4541/4541 [00:48<00:00, 93.12it/s]



Loss: 0.13907967507839203


100%|██████████| 18164/18164 [02:01<00:00, 149.87it/s]
100%|██████████| 4541/4541 [00:48<00:00, 94.14it/s]



Loss: 0.1359969675540924


100%|██████████| 18164/18164 [02:02<00:00, 148.05it/s]
100%|██████████| 4541/4541 [00:47<00:00, 94.65it/s]



Loss: 0.13932174444198608


100%|██████████| 18164/18164 [01:59<00:00, 151.84it/s]
100%|██████████| 4541/4541 [00:47<00:00, 94.78it/s]



Loss: 0.13560332357883453


100%|██████████| 18164/18164 [01:59<00:00, 151.97it/s]
100%|██████████| 4541/4541 [00:47<00:00, 95.67it/s]



Loss: 0.1320076882839203


100%|██████████| 18164/18164 [01:59<00:00, 152.18it/s]
100%|██████████| 4541/4541 [00:47<00:00, 94.79it/s]



Loss: 0.13440537452697754


100%|██████████| 18164/18164 [01:59<00:00, 151.85it/s]
100%|██████████| 4541/4541 [00:47<00:00, 94.62it/s]



Loss: 0.13477829098701477


100%|██████████| 18164/18164 [01:59<00:00, 152.12it/s]
100%|██████████| 4541/4541 [00:47<00:00, 95.34it/s]



Loss: 0.1328144669532776


100%|██████████| 18164/18164 [01:59<00:00, 152.18it/s]
100%|██████████| 4541/4541 [00:47<00:00, 95.17it/s]



Loss: 0.13253876566886902


100%|██████████| 18164/18164 [01:59<00:00, 152.19it/s]
100%|██████████| 4541/4541 [00:47<00:00, 95.24it/s]



Loss: 0.13032759726047516


100%|██████████| 18164/18164 [02:01<00:00, 149.19it/s]
 48%|████▊     | 2166/4541 [00:22<00:25, 92.38it/s]

In [None]:
print(test_accuracy_epoch)

In [12]:
def test(net):
    correct = 0
    total = 0
    with torch.no_grad():
        for i in tqdm(range(0, len(te_set), BATCH_SIZE)):
            batch_X = [x for x, y_ in te_set[i:i+BATCH_SIZE]]
            batch_y = [y_ for x, y_ in te_set[i:i+BATCH_SIZE]]

            tensor_batch_X = torch.tensor(batch_X).float().view(-1, 300).to(device)
            tensor_batch_y = torch.tensor(batch_y).float().to(device)
            batch_out = net(tensor_batch_X)

            out_maxes = [torch.argmax(val) for val in batch_out]
            target_maxes = [torch.argmax(val) for val in tensor_batch_y]
            
            for i, j in zip(out_maxes, target_maxes):
              if i == j:
                correct +=1
              total += 1
    return round(correct/total, 3)

In [18]:
import matplotlib.pyplot as plt

In [None]:
plt.plot(list(range(1,30*8+1)), test_accuracy_epoch)
plt.show()

In [None]:
torch.save(net.state_dict(), "/content/drive/My Drive/nn_dict")
torch.save(net, "/content/drive/My Drive/nn_object")