# Generative parameters

In [None]:
# import csv
# with open('dataset.0.csv', newline='') as csvfile:
#     spamreader = csv.reader(csvfile, delimiter='\t', quotechar='|')
#     for row in spamreader:
#         # print(', '.join(row))
#         print(len(row))

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


# Load the tab-separated dataset. Each row is split into 64 model inputs
# (columns 0-63) and 4 target values (columns 64-67); any extra columns are ignored.
# ndmin=2 keeps a single-row file two-dimensional so the column slicing below
# still works (np.loadtxt would otherwise return a 1-D array).
dataset = np.loadtxt('dataset.0.csv', delimiter='\t', ndmin=2)
if dataset.shape[1] < 68:
    # Fail loudly instead of silently producing a too-narrow y that only
    # breaks later with a confusing shape error inside the loss function.
    raise ValueError(
        f"expected at least 68 tab-separated columns, got {dataset.shape[1]}"
    )
X = dataset[:, 0:64]
y = dataset[:, 64:68]

X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.float32)

In [None]:
class SpectrumEncoder(nn.Module):
    """Fully connected network producing sigmoid-squashed outputs.

    Architecture: Linear -> ReLU -> Linear -> ReLU -> Linear -> Sigmoid.
    The default widths (64 -> 128 -> 96 -> 4) reproduce the original fixed
    architecture. Submodule attribute names are kept stable (``hidden1``,
    ``hidden2``, ``output``) so ``state_dict`` keys remain compatible with
    previously saved weights.

    Args:
        n_inputs: size of the last dimension of the input tensor.
        n_hidden1: width of the first hidden layer.
        n_hidden2: width of the second hidden layer.
        n_outputs: number of outputs, each passed through a sigmoid.
    """

    def __init__(self, n_inputs=64, n_hidden1=128, n_hidden2=96, n_outputs=4):
        super().__init__()
        self.hidden1 = nn.Linear(n_inputs, n_hidden1)
        self.act1 = nn.ReLU()
        self.hidden2 = nn.Linear(n_hidden1, n_hidden2)
        self.act2 = nn.ReLU()
        self.output = nn.Linear(n_hidden2, n_outputs)
        self.act_output = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid outputs in [0, 1], one per output unit.

        ``x`` must have ``n_inputs`` as its last dimension; leading
        dimensions (e.g. batch) are passed through by ``nn.Linear``.
        """
        x = self.act1(self.hidden1(x))
        x = self.act2(self.hidden2(x))
        x = self.act_output(self.output(x))
        return x
    
model = SpectrumEncoder()
print(model)

# train the model
loss_fn = nn.BCELoss()  # binary cross entropy on the sigmoid outputs
optimizer = optim.Adam(model.parameters(), lr=0.001)

n_epochs = 100
batch_size = 4

model.train()
for epoch in range(n_epochs):
    # Mini-batches are taken in row order; the data is not shuffled.
    for i in range(0, len(X), batch_size):
        Xbatch = X[i:i + batch_size]
        ybatch = y[i:i + batch_size]
        y_pred = model(Xbatch)
        loss = loss_fn(y_pred, ybatch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# Compute element-wise accuracy: every output is rounded to 0/1 and compared
# with the matching target entry; the mean is taken over all rows AND all
# outputs (this is not a per-row exact-match rate).
# NOTE(review): only meaningful if the targets are exactly 0/1 — confirm.
model.eval()
with torch.no_grad():  # evaluation only; no autograd graph is needed
    y_pred = model(X)
accuracy = (y_pred.round() == y).float().mean()
print(f"Accuracy {accuracy.item():.4f}")


In [None]:
# Persist only the learned parameters (the state_dict), not the module object.
torch.save(model.state_dict(), f='generative_approach.model')

In [None]:
# Make multi-label class predictions: each sigmoid output above 0.5 becomes 1.
model.eval()
with torch.no_grad():  # inference only; no autograd graph is needed
    predictions = (model(X) > 0.5).int()

# Show predicted vs expected vectors for at most the first 100 rows;
# min() avoids an IndexError when the dataset has fewer rows.
for i in range(min(100, len(y))):
    print(f"{predictions[i].tolist()} (expected {y[i].tolist()})")

print("done")