In [30]:
import torch
from torch import nn
from tqdm import tqdm

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Fix the global RNG before any model is built so weight initialisation (and hence
# the whole training run) is reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)


In [43]:
# The four XOR input patterns. XOR has no unseen inputs to hold out, so the
# "test" split is just the same four points in a different order.
X_train = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=torch.float32, device=device)
X_test = torch.tensor([[0, 1], [1, 1], [1, 0], [0, 0]], dtype=torch.float32, device=device)

# Targets are column-shaped (N, 1) float tensors to match the single output unit:
# label = x0 XOR x1, i.e. 1.0 where the two inputs differ.
y_train = (X_train[:, :1] != X_train[:, 1:]).float()
y_test = (X_test[:, :1] != X_test[:, 1:]).float()

In [89]:
class XORRNN(nn.Module):
    """Small feed-forward network used to learn XOR.

    Despite the ``RNN`` in its name there is no recurrence: the forward pass is
    Linear -> ReLU -> Linear -> ReLU -> Linear, with no activation on the output.

    Args:
        input: number of input features.
        hidden: width of both hidden layers.
        output: number of output units.
    """

    def __init__(self, input, hidden, output) -> None:
        super().__init__()
        # Keep this creation order: it fixes how parameter initialisation consumes
        # the RNG and keeps the state_dict keys as fc1 / fc2 / fc3.
        self.fc1 = nn.Linear(input, hidden)
        self.act = nn.ReLU()  # stateless, so a single instance serves both hidden layers
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, output)

    def forward(self, x):
        """Map inputs (last dimension of size ``input``) to raw, un-squashed outputs."""
        for hidden_layer in (self.fc1, self.fc2):
            x = self.act(hidden_layer(x))
        return self.fc3(x)

# 2 inputs -> two hidden layers of width 5 -> 1 raw output, regressed directly onto
# the 0/1 targets with mean-squared error and plain SGD.
model = XORRNN(input=2, hidden=5, output=1).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)

In [91]:
def _accuracy(net, inputs, targets):
    """Return the fraction of rows whose rounded prediction matches the target.

    The whole batch goes through a single forward pass under ``torch.no_grad()``.
    The network is switched to eval mode for that pass and then restored to the
    train/eval mode it was in before the call.

    Args:
        net: module whose output has the same shape as ``targets``
            (assumed for this notebook's XOR data — TODO confirm for other uses).
        inputs: batch of input rows.
        targets: batch of target rows, shape (N, n_outputs).

    Returns:
        ``correct / len(targets)`` as a Python float. A row counts as correct only
        when every output column rounds to its target.
    """
    was_training = net.training
    net.eval()
    with torch.no_grad():
        preds = net(inputs).round()
    net.train(was_training)
    # all(dim=1) keeps the count at one per row even for multi-output models.
    correct = (preds == targets).all(dim=1).sum().item()
    return correct / len(targets)


epochs = 10000
eval_every = 1000  # report test accuracy after every `eval_every` optimisation steps

# NOTE: re-running this cell keeps training the existing `model`/`optimizer`;
# re-run the model-definition cell first for a fresh run.
model.train()  # in case an interrupted earlier run left the model in eval mode
for epoch in tqdm(range(epochs)):
    # Full-batch gradient step on all four XOR patterns.
    optimizer.zero_grad()
    y_pred = model(X_train)
    loss = criterion(y_pred, y_train)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % eval_every == 0:
        accuracy = _accuracy(model, X_test, y_test)
        # tqdm.write prints above the progress bar instead of splitting it into fragments.
        tqdm.write(f"Epoch {epoch}: Accuracy: {accuracy}")
            
    

 12%|█▏        | 1230/10000 [00:01<00:07, 1136.65it/s]

Epoch 999: Accuracy: 1.0


 21%|██▏       | 2138/10000 [00:01<00:06, 1126.80it/s]

Epoch 1999: Accuracy: 1.0


 32%|███▏      | 3169/10000 [00:02<00:06, 1136.13it/s]

Epoch 2999: Accuracy: 1.0


 42%|████▏     | 4212/10000 [00:03<00:05, 1143.11it/s]

Epoch 3999: Accuracy: 1.0


 51%|█████▏    | 5127/10000 [00:04<00:04, 1134.47it/s]

Epoch 4999: Accuracy: 1.0


 62%|██████▏   | 6152/10000 [00:05<00:03, 1136.64it/s]

Epoch 5999: Accuracy: 1.0


 72%|███████▏  | 7187/10000 [00:06<00:02, 1151.78it/s]

Epoch 6999: Accuracy: 1.0


 82%|████████▏ | 8231/10000 [00:07<00:01, 1140.10it/s]

Epoch 7999: Accuracy: 1.0


 92%|█████████▏| 9151/10000 [00:08<00:00, 1123.07it/s]

Epoch 8999: Accuracy: 1.0


100%|██████████| 10000/10000 [00:08<00:00, 1118.51it/s]

Epoch 9999: Accuracy: 1.0



