In [1]:
import torch
from sklearn.model_selection import train_test_split

In [2]:
# defining simple nn model with 2 fc layers
class SimpleNet(torch.nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = torch.nn.Linear(4,2)
        self.fc2 = torch.nn.Linear(2,1)

    def forward(self, x):
        output = torch.relu(self.fc1(x))
        output = torch.sigmoid(self.fc2(output))
        return output

In [3]:
# simple supervised learning task to classify if the sum of features is greater than 10
features = torch.randint(high=8,size=(500,4)).type(torch.FloatTensor)
labels = torch.Tensor([x>10 for x in  torch.sum(features,dim=1)]).reshape(-1,1)

In [4]:
# splitting features and labels
X=features
y=labels
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

In [5]:
# initialise nn model
model=SimpleNet()

In [6]:
# definig loss function and optimizer
criterion=torch.nn.BCELoss()
optimizer=torch.optim.Adam(model.parameters(),lr=0.001)

In [7]:
# model training
number_of_epochs=20
total,correct=0,0
for epoch in range(number_of_epochs):
    for feat, label in zip(X_train,y_train):
        optimizer.zero_grad()
        outputs = model(feat)
        outputs_class = model(feat).round()
        loss = criterion(outputs, label)
        loss.backward()
        optimizer.step() 
        total += label.size(0)
        correct += (outputs_class.eq(label).sum().item())  

    if (epoch+1)%10 == 0: 
      print(f'epoch:{epoch+1},loss = {loss.item()}, accuracy = {correct / total}')

epoch:10,loss = 1.2445353269577026, accuracy = 0.80025
epoch:20,loss = 1.0772840976715088, accuracy = 0.814375


In [8]:
# evaluating the model
correct, total, loss = 0, 0, 0.0
with torch.no_grad():
        for features,labels in zip(X_test,y_test):
                outputs = model(features)
                outputs_class = model(features).round()
                loss += criterion(outputs, labels).item()
                total += labels.size(0)
                correct += outputs_class.eq(labels).sum().item()
                # print(f'correctly detected  {outputs_class.eq(labels).sum().item()} out of {labels.size(0)}')
accuracy = correct / total
print(f'Loss: {loss}, Test Accuracy : {accuracy}')

Loss: 27.750593127217144, Test Accuracy : 0.86


In [None]:
# def freeze_network(model):
#     for name, p in model.named_parameters():
#         if "fc1" in name:
#             p.requires_grad = False

In [None]:
# freeze_network(model)


In [10]:
# saving model checkpoints
py_trace = torch.jit.trace(model,(torch.rand(1, 4)))
py_trace.save('models/py_trace_model.pt')