In [1]:
import torch
from torch import nn
from torch import optim
import random
SEED = 42
# Seed Python's `random` (used by `data_fn` to sample points) ...
random.seed(SEED)
# ... and torch, so the model's weight initialisation is reproducible too.
torch.manual_seed(SEED)

In [24]:
def in_mandelbrot(c, max_iter = 10000):
    """Return True if `c` appears to belong to the Mandelbrot set.

    Iterates z -> z**2 + c starting from z = 0. Before each update the squared
    modulus of z is checked; once it reaches 4 (|z| >= 2) the orbit has escaped
    and False is returned. If no check fails within `max_iter` rounds, the point
    is treated as a member (the value produced by the final update is not checked).
    """
    z = 0
    for _ in range(max_iter):
        modulus_sq = z.real**2 + z.imag**2
        if modulus_sq >= 4:
            return False
        z = z**2 + c
    return True

def data_fn(max_val = 2, size = 1000, max_iter = 10000):
    """Sample `size` random points in the square [-max_val, max_val]^2 and label them.

    Each point c = a + bi is drawn with `random.uniform` (real part first, then
    imaginary part) and labelled 1.0 if `in_mandelbrot(c, max_iter)` reports it
    as a member of the Mandelbrot set, else 0.0.

    Args:
        max_val: half-width of the sampling square.
        size: number of points to generate.
        max_iter: iteration budget forwarded to `in_mandelbrot`. Lowering it
            makes generation faster, at the cost of labelling some slowly-escaping
            points as "inside".

    Returns:
        (x, y): `x` has shape (size, 2) holding [real, imag] coordinates and `y`
        has shape (size, 1) holding the 0.0/1.0 labels, both in torch's default
        floating dtype.
    """
    coords = []
    labels = []
    for _ in range(size):
        c = complex(random.uniform(-max_val, max_val), random.uniform(-max_val, max_val))
        coords.append([c.real, c.imag])
        labels.append([1.0 if in_mandelbrot(c, max_iter) else 0.0])
    # reshape keeps the 2-D shapes when size == 0; torch.tensor([]) alone would
    # give a 1-D tensor of shape (0,).
    return torch.tensor(coords).reshape(-1, 2), torch.tensor(labels).reshape(-1, 1)

x, y = data_fn()
print(f'x.shape: {x.shape}')
print(f'y.shape: {y.shape}')

# Sanity checks on the sampled batch: one label per point, 2-D coordinates
# inside data_fn's default sampling square [-2, 2]^2, and strictly binary labels.
assert x.shape[0] == y.shape[0] and x.shape[1] == 2 and y.shape[1] == 1
assert x.abs().max().item() <= 2
assert set(y.unique().tolist()) <= {0.0, 1.0}

x.shape: torch.Size([1000, 2])
y.shape: torch.Size([1000, 1])


In [27]:
device = "cuda" if torch.cuda.is_available() else "cpu"

def train(data_fn, model, loss_fn, optimizer, iterations, log_every=100):
    """Train `model` for `iterations` steps on a freshly generated batch per step.

    Args:
        data_fn: zero-argument callable returning an `(x, y)` pair of tensors.
        model: module mapping `x` to predictions. Assumed to output probabilities
            in [0, 1] (e.g. ending in `nn.Sigmoid`); accuracy thresholds them at 0.5.
            Its parameters must already be on `device`, since each batch is moved there.
        loss_fn: callable `loss_fn(pred, y)` returning a scalar tensor.
        optimizer: optimizer over `model.parameters()`.
        iterations: number of optimisation steps.
        log_every: print loss and accuracy every `log_every` steps; 0 disables logging.
    """
    model.train()
    for i in range(iterations):
        x, y = data_fn()
        x, y = x.to(device), y.to(device)

        pred = model(x)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if log_every and i % log_every == 0:
            # `pred` holds probabilities, not hard labels: threshold before comparing
            # with the 0/1 targets (exact float equality would almost never match).
            correct = (pred >= 0.5).to(y.dtype) == y
            acc = correct.sum().item() / len(y)
            # note we don't need to look at the validation metrics since each training data set
            # is randomly generated
            print(f'iteration {i}: loss {loss.item()} accuracy: {acc}')
            
def test(data, model, loss_fn):
    model.eval()
    with torch.no_grad():
        x, y = data
        x, y = x.to(device), y.to(device)

        pred = model(x)
        loss = loss_fn(pred, y)
        return loss.item()



In [28]:
TRAIN_ITERATIONS = 1000

# Fully-connected classifier: (real, imag) -> probability that the point is in the set.
# Moved to `device` because `train` moves every batch there; without this the
# forward pass fails with a device mismatch whenever CUDA is available.
model = nn.Sequential(
    nn.Linear(2, 64),
    nn.ReLU(),
    nn.Linear(64, 32),
    nn.ReLU(),
    nn.Linear(32, 16),
    nn.ReLU(),
    nn.Linear(16, 8),
    nn.ReLU(),
    nn.Linear(8, 1),
    nn.Sigmoid()
).to(device)
loss_fn = nn.BCELoss()  # expects probabilities, matching the final Sigmoid
# Build the optimizer only after the model is on its final device.
optimizer = optim.Adam(model.parameters())
train(data_fn, model, loss_fn, optimizer, TRAIN_ITERATIONS)

iteration 0: loss 0.808846652507782 validation loss 0.8085283637046814
iteration 100: loss 0.19116398692131042 validation loss 0.21022644639015198
iteration 200: loss 0.13568860292434692 validation loss 0.12285153567790985
iteration 300: loss 0.09775923192501068 validation loss 0.0988861545920372
iteration 400: loss 0.06899090111255646 validation loss 0.07769561558961868
iteration 500: loss 0.05850527435541153 validation loss 0.03515288606286049
iteration 600: loss 0.03838123381137848 validation loss 0.0355440229177475
iteration 700: loss 0.02677726000547409 validation loss 0.023957708850502968
iteration 800: loss 0.0252367090433836 validation loss 0.037648171186447144
iteration 900: loss 0.02494014985859394 validation loss 0.033861760050058365
