In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from malib.train import train

import matplotlib.pyplot as plt

In [17]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class SpiralDataset(Dataset):
    """Synthetic 2-D "spiral" classification dataset with K interleaved arms.

    Each of the K classes contributes N points laid out along one arm of a
    spiral. The radius grows linearly from 0 to 1. The angle sweeps a 4-radian
    interval that is offset per class, with Gaussian jitter added.

    Args:
        N: number of points per class.
        K: number of classes (spiral arms).
        noise: standard deviation of the Gaussian jitter added to the angle.
            Defaults to 0.2, the value this notebook originally hard-coded.
        seed: optional seed for a private ``np.random.default_rng``. When
            ``None`` (default), the global ``np.random`` state is used, so
            ``np.random.seed(...)`` still controls the output.

    Attributes:
        X: float32 tensor of shape (N*K, 2) with the point coordinates.
        y: int64 tensor of shape (N*K,) with class labels in ``[0, K)``.
    """

    def __init__(self, N, K, noise=0.2, seed=None):
        X, y = self.generate_spiral_dataset(N, K, noise=noise, seed=seed)

        # Convert numpy arrays to PyTorch tensors (float features, int64 labels)
        self.X = torch.from_numpy(X).float()
        self.y = torch.from_numpy(y).long()

    def generate_spiral_dataset(self, N, K, noise=0.2, seed=None):
        """Return numpy arrays ``(X, y)`` of shapes (N*K, 2) and (N*K,).

        Rows ``N*j .. N*(j+1)-1`` belong to class ``j``.
        """
        # np.random (module) and a Generator both expose standard_normal(size),
        # so the unseeded path draws exactly what np.random.randn(N) used to.
        rng = np.random if seed is None else np.random.default_rng(seed)

        X = np.zeros((N * K, 2))  # data matrix (each row = single example)
        # int64 rather than uint8: uint8 silently wraps class ids once K > 256.
        y = np.zeros(N * K, dtype=np.int64)  # class labels

        r = np.linspace(0.0, 1.0, N)  # radius; identical for every class
        for j in range(K):
            ix = slice(N * j, N * (j + 1))
            # angle (theta): per-class 4-radian sweep plus Gaussian jitter
            t = np.linspace(j * 4, (j + 1) * 4, N) + rng.standard_normal(N) * noise
            X[ix] = np.c_[r * np.sin(t), r * np.cos(t)]
            y[ix] = j

        return X, y

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

N = 100  # number of points per class
K = 3  # number of classes
SEED = 0  # fixes spiral noise and shuffle order so Restart & Run All is reproducible

# Seed both RNGs used below: numpy draws the angular noise when the datasets
# are generated; torch drives DataLoader shuffling.
np.random.seed(SEED)
torch.manual_seed(SEED)

spiral_dataset_train = SpiralDataset(N, K)
spiral_dataset_test = SpiralDataset(N, K)

# Full-batch loaders: one batch holds every point of a dataset (N * K samples).
train_loader = DataLoader(spiral_dataset_train, batch_size=N * K, shuffle=True)
# Evaluate on the independently generated test spiral, not the training data.
test_loader = DataLoader(spiral_dataset_test, batch_size=N * K, shuffle=False)

In [3]:
class MLP(nn.Module):
    """Small fully connected classifier: two ReLU hidden layers, raw logits out.

    Args:
        in_features: size of each input vector (2 for the 2-D spiral points).
        hidden_features: width of both hidden layers.
        num_classes: number of output logits; should match the dataset's K.

    The defaults reproduce the original 2 -> 10 -> 10 -> 3 architecture. The
    layer attribute names (fc1, fc2, fc3) are unchanged, so previously saved
    state_dicts still load.
    """

    def __init__(self, in_features=2, hidden_features=10, num_classes=3):
        super().__init__()

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, hidden_features)
        self.fc3 = nn.Linear(hidden_features, num_classes)

    def forward(self, x):
        """Map inputs of shape (..., in_features) to logits (..., num_classes).

        No softmax is applied, so pair this with a loss that expects logits
        (e.g. cross-entropy). TODO: confirm against what malib.train uses.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

In [4]:
# Seed torch right before construction so the random weight initialisation is
# identical on every Restart & Run All, regardless of what earlier cells drew.
torch.manual_seed(0)
model = MLP()

In [6]:
NUM_EPOCHS = 500  # training length passed to malib's train()
train(model, train_loader, epochs=NUM_EPOCHS)

Epoch 0/500. Loss=0.776606023311615
Epoch 1/500. Loss=0.7697689533233643
Epoch 2/500. Loss=0.7627350091934204
Epoch 3/500. Loss=0.7559328675270081
Epoch 4/500. Loss=0.7493581771850586
Epoch 5/500. Loss=0.7428911924362183
Epoch 6/500. Loss=0.7364201545715332
Epoch 7/500. Loss=0.7299900650978088
Epoch 8/500. Loss=0.7236257195472717
Epoch 9/500. Loss=0.7173200845718384
Epoch 10/500. Loss=0.7109372615814209
Epoch 11/500. Loss=0.7042123675346375
Epoch 12/500. Loss=0.6971761584281921
Epoch 13/500. Loss=0.6898736357688904
Epoch 14/500. Loss=0.6825350522994995
Epoch 15/500. Loss=0.675251305103302
Epoch 16/500. Loss=0.6681879162788391
Epoch 17/500. Loss=0.6615259051322937
Epoch 18/500. Loss=0.6550375819206238
Epoch 19/500. Loss=0.6485284566879272
Epoch 20/500. Loss=0.6419057250022888
Epoch 21/500. Loss=0.6351152062416077
Epoch 22/500. Loss=0.6281653642654419
Epoch 23/500. Loss=0.6211393475532532
Epoch 24/500. Loss=0.6140297055244446
Epoch 25/500. Loss=0.6069065928459167
Epoch 26/500. Loss=0.599