# Building and Training Neural Networks with PyTorch

A neural network is a machine-learning model that uses interconnected layers of nodes to process data and find patterns in it.

PyTorch is a deep-learning framework that lets us build and train such networks.

The `torch.nn` module is a collection of pre-defined layers, activation functions, loss functions, and utilities for building and training neural networks.

In [1]:
# define the Neural Network Class
import torch
import torch.nn as nn

class SimpleRNN(nn.Module):
    """Two-layer feedforward network: Linear -> ReLU -> Linear.

    Despite its name, this model has no recurrent connections; it is a plain
    multilayer perceptron. The name is kept so existing code that calls
    ``SimpleRNN()`` keeps working.

    Args:
        in_features: size of each input sample (2 for the XOR input pairs).
        hidden_features: number of hidden units in the ReLU layer.
        out_features: size of each output sample.
    """

    def __init__(self, in_features: int = 2, hidden_features: int = 4,
                 out_features: int = 1):
        super().__init__()
        # Attribute names fc1/fc2 are unchanged so state_dict keys stay compatible.
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``x`` of shape (..., in_features) to (..., out_features)."""
        hidden = torch.relu(self.fc1(x))
        # No activation on the output layer: raw values are returned.
        return self.fc2(hidden)

In [7]:
# Prepare the data: the XOR truth table.
# Each row of X_train is an input pair; the same row of Y_train is its label.
X_train = torch.tensor(
    [[0, 0],
     [0, 1],
     [1, 0],
     [1, 1]],
    dtype=torch.float32,
)
Y_train = torch.tensor([[0], [1], [1], [0]], dtype=torch.float32)

In [8]:
# Instantiate the model, define the loss function and the optimizer.
import torch.optim as optim

# Seed the RNG before building the model so the random weight initialisation,
# and therefore the training losses and predictions, are reproducible.
torch.manual_seed(0)

model = SimpleRNN()
loss_func = nn.MSELoss()  # mean squared error between outputs and the 0/1 targets
learning_rate = 0.1
optimizer = optim.SGD(model.parameters(), lr=learning_rate)

In [13]:
# Training the model.
# NOTE(review): in the recorded run below, the loss stays at 0.2500 and every
# prediction is ~0.5. The network has not learned XOR; consider more epochs,
# a different learning rate, or another optimizer.
num_epochs = 100

# Nothing inside the loop switches the model out of training mode, so set it once.
model.train()
for epoch in range(num_epochs):
    # Forward pass
    outputs = model(X_train)
    loss = loss_func(outputs, Y_train)

    # Backward pass and optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 10 == 0:
        # Report against the real epoch count (previously hard-coded as "/10").
        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

Epoch [10/10], Loss: 0.2500
Epoch [20/10], Loss: 0.2500
Epoch [30/10], Loss: 0.2500
Epoch [40/10], Loss: 0.2500
Epoch [50/10], Loss: 0.2500
Epoch [60/10], Loss: 0.2500
Epoch [70/10], Loss: 0.2500
Epoch [80/10], Loss: 0.2500
Epoch [90/10], Loss: 0.2500
Epoch [100/10], Loss: 0.2500


In [11]:
# Testing the model: run inference on all four XOR input combinations.
model.eval()  # switch the model to evaluation mode
test_data = torch.tensor(
    [[0.0, 0.0],
     [0.0, 1.0],
     [1.0, 0.0],
     [1.0, 1.0]]
)
with torch.no_grad():  # inference only, so skip gradient tracking
    predictions = model(test_data)
print(f'Predictions:\n{predictions}')

Predictions:
tensor([[0.5080],
        [0.5005],
        [0.5005],
        [0.4930]])
