In [None]:
#1D CNN Test
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

class CNN(nn.Module):
    """Small 1D CNN classifier over a flat feature vector.

    Each sample's features are treated as a single-channel 1D signal of
    length ``in_features``. Both convolutions use kernel size 1 and both
    max-pools use kernel size 1 / stride 1, so the signal length is preserved
    through the conv stack and the flattened size is ``80 * in_features``.

    Args:
        in_features: Number of input features per sample (default 13).
        num_classes: Number of output logits (default 3).
    """

    def __init__(self, in_features=13, num_classes=3):
        super(CNN, self).__init__()
        self.in_features = in_features
        self.conv1 = nn.Conv1d(1, 64, 1)
        # NOTE: MaxPool1d(1, 1) is an identity op; kept so the module layout is unchanged.
        self.pool1 = nn.MaxPool1d(1, 1)
        self.conv2 = nn.Conv1d(64, 80, 1)
        self.pool2 = nn.MaxPool1d(1, 1)
        self.act1 = nn.ReLU()
        self.act2 = nn.LeakyReLU()
        # Fully connected head: 80 channels * in_features positions after flattening.
        self.fl1 = nn.Linear(80 * in_features, 512)
        self.fl2 = nn.Linear(512, 220)
        self.fl3 = nn.Linear(220, num_classes)

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes).

        ``x`` must be viewable as (batch, 1, in_features).
        """
        x = x.view(x.size(0), 1, self.in_features)  # (batch_size, num_channels, length)
        x = self.conv1(x)
        x = self.act1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.act2(x)
        x = self.pool2(x)
        x = x.reshape(x.shape[0], -1)  # Flatten to (batch, 80 * in_features)
        x = self.fl1(x)
        x = self.act1(x)
        x = self.fl2(x)
        x = self.act2(x)
        x = self.fl3(x)
        return x  # raw logits: nn.CrossEntropyLoss applies log-softmax itself

class WDat(Dataset):
    """Wine dataset loaded from a CSV file fully into memory.

    The CSV must have one header row, the class label in column 0 and the
    features in the remaining columns. Labels are shifted down by 1, so a
    label column with values 1..K becomes class indices 0..K-1 as
    ``nn.CrossEntropyLoss`` expects.

    Args:
        path: Location of the CSV file (defaults to the original Colab path).
    """

    def __init__(self, path="/content/wine.csv"):
        # ndmin=2 keeps the array 2-D even when the file holds a single data
        # row, so the column slicing below does not fail.
        self.xy = np.loadtxt(path, delimiter=",", dtype=np.float32, skiprows=1, ndmin=2)
        self.x = torch.from_numpy(self.xy[:, 1:])  # features: every column after the label
        self.y = torch.from_numpy(self.xy[:, 0]).long() - 1  # labels 1..K -> 0..K-1
        self.nsam = len(self.x)

    def __len__(self):
        return self.nsam

    def __getitem__(self, idx):
        """Return the (features, label) pair for sample ``idx``."""
        return self.x[idx], self.y[idx]

mydata = WDat()
myloader = DataLoader(mydata, batch_size=16, shuffle=True)
num_epochs = 60
net = CNN()
loss = nn.CrossEntropyLoss()
optimizer = optim.Adamax(net.parameters(), lr=0.001)

# Running totals over all epochs, counted per SAMPLE rather than per batch so a
# short final batch is weighted by its real size.
# avg_acc: correct predictions so far; num_acc: samples seen so far.
avg_acc, num_acc = 0, 0
for e in range(num_epochs):
    epoch_correct, epoch_seen, epoch_loss = 0, 0, 0.0
    for xin, yin in myloader:
        y_pred = net(xin)
        # Training accuracy of the predictions made before this batch's update.
        epoch_correct += (y_pred.argmax(dim=1) == yin).sum().item()
        batch_size = len(yin)
        epoch_seen += batch_size
        l = loss(y_pred, yin)
        # CrossEntropyLoss returns a batch mean; re-weight by batch size so the
        # epoch figure is a true per-sample mean, not just the last batch's loss.
        epoch_loss += l.item() * batch_size
        optimizer.zero_grad()
        l.backward()
        optimizer.step()
    avg_acc += epoch_correct
    num_acc += epoch_seen
    print(f"Epoch: {e+1} Loss: {epoch_loss/epoch_seen} Avg Accuracy: {epoch_correct/epoch_seen}")
print(f"Total Average Accuracy {avg_acc/num_acc}")

Epoch: 1 Loss: 1.8161826133728027 Avg Accuracy: 0.359375
Epoch: 2 Loss: 1.1336917877197266 Avg Accuracy: 0.3645833333333333
Epoch: 3 Loss: 0.6326448321342468 Avg Accuracy: 0.5364583333333334
Epoch: 4 Loss: 0.9056510925292969 Avg Accuracy: 0.65625
Epoch: 5 Loss: 0.8486098647117615 Avg Accuracy: 0.6145833333333334
Epoch: 6 Loss: 0.6760177612304688 Avg Accuracy: 0.6614583333333334
Epoch: 7 Loss: 0.48215997219085693 Avg Accuracy: 0.703125
Epoch: 8 Loss: 0.7024503946304321 Avg Accuracy: 0.6927083333333334
Epoch: 9 Loss: 0.7251882553100586 Avg Accuracy: 0.6875
Epoch: 10 Loss: 0.7142814993858337 Avg Accuracy: 0.796875
Epoch: 11 Loss: 0.47937220335006714 Avg Accuracy: 0.765625
Epoch: 12 Loss: 0.12516744434833527 Avg Accuracy: 0.7552083333333334
Epoch: 13 Loss: 1.6527082920074463 Avg Accuracy: 0.734375
Epoch: 14 Loss: 1.092834711074829 Avg Accuracy: 0.6302083333333334
Epoch: 15 Loss: 0.3312925696372986 Avg Accuracy: 0.8125
Epoch: 16 Loss: 0.5108392834663391 Avg Accuracy: 0.7708333333333334
Epoc