In [None]:
import import_ipynb
import torch
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import numpy as np
from PIL import Image
import pytorch_lightning as pl
import torch.nn.functional as F
from torch.optim import Adam

from sampler import *

In [None]:
# Number of samples drawn for each class (local and non-local).
n = 1000

# Seed python `random`, numpy and torch in one call. This makes weight init and
# the sklearn train/test split (which falls back to numpy's global RNG when no
# random_state is given) reproducible on Restart & Run All.
# NOTE(review): presumably `sample_no_signalling` also draws from one of these
# RNGs — confirm in `sampler`.
SEED = 42
pl.seed_everything(SEED)

In [None]:
# Draw n samples per class: first with flag True, then with flag False.
# NOTE(review): the flag appears to select local (True) vs non-local (False)
# boxes — confirm against `sampler`.
local, non_local = (sample_no_signalling(n, is_local) for is_local in (True, False))

In [None]:
# Concatenate both classes into one sample list, labelled 0 = local, 1 = non-local.
# Unpacking into a list concatenates whether the sampler returns lists or
# stacked numpy arrays (for arrays, `local + non_local` would add elementwise
# and leave `data` half as long as `targets`).
data = [*local, *non_local]
# Size the labels from the actual sample counts so they always align with `data`.
targets = [0] * len(local) + [1] * len(non_local)

In [None]:
class MyDataset(Dataset):
    """Map-style dataset pairing sample arrays with integer class labels.

    Args:
        data: indexable collection of array-likes, one per sample. Each item
            is cast to ``uint8`` and wrapped in a PIL image before
            ``transform`` is applied.
        targets: labels aligned with ``data``; its length defines the
            dataset length.
        transform: callable applied to the PIL image (e.g. a torchvision
            ``transforms.Compose``); its result is the returned sample.
    """

    def __init__(self, data, targets, transform):
        self.data = data
        self.targets = targets
        self.transform = transform

    def __len__(self):
        # The dataset length is taken from the labels, not from `data`.
        return len(self.targets)

    def __getitem__(self, idx):
        # NOTE(review): the uint8 cast truncates fractional values (e.g.
        # probabilities in [0, 1] would collapse to 0/1) — confirm the sampler
        # emits integer pixel values in 0..255.
        raw = self.data[idx]
        image = Image.fromarray(raw.astype(np.uint8))
        return self.transform(image), self.targets[idx]

In [None]:
class MyDataModule(pl.LightningDataModule):
    """Lightning data module that randomly splits ``(X, y)`` into train/test sets.

    Args:
        X: samples, forwarded to ``MyDataset`` as ``data``.
        y: labels aligned with ``X``.
        transform: per-sample transform forwarded to ``MyDataset``.
        batch_size: batch size for both loaders (default 10).
        test_size: held-out fraction passed to ``train_test_split``
            (0.25, which is also sklearn's default).
        seed: ``random_state`` for ``train_test_split``; ``None`` uses
            numpy's global RNG.

    Note: no ``val_dataloader`` is defined, so Lightning skips the model's
    validation loop during ``fit``.
    """

    def __init__(self, X, y, transform, batch_size=10, test_size=0.25, seed=None):
        super().__init__()
        self.X = X
        self.y = y
        self.transform = transform
        self.batch_size = batch_size
        self.test_size = test_size
        self.seed = seed
        self.train_dataset = None
        self.test_dataset = None

    def setup(self, stage=None):
        # Lightning calls setup() for each stage (fit, test, ...). Split only
        # once so a later stage cannot draw a test set that overlaps the data
        # the model was trained on.
        if self.train_dataset is not None:
            return
        X_train, X_test, y_train, y_test = train_test_split(
            self.X, self.y, test_size=self.test_size, random_state=self.seed
        )
        # Use the transform handed to this module rather than a notebook global.
        self.train_dataset = MyDataset(X_train, y_train, self.transform)
        self.test_dataset = MyDataset(X_test, y_test, self.transform)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def test_dataloader(self):
        # Evaluation does not need a random order.
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)

In [None]:
class Net(nn.Module):
    """Small LeNet-style CNN for binary classification of single-channel images.

    Two conv(5x5) + ReLU + 2x2 max-pool stages feed two linear layers and a
    sigmoid, so ``forward`` returns probabilities of shape ``(batch, 1)``.

    Args:
        fc1_in_features: flattened size of the conv feature map feeding
            ``fc1``. The default 150544 = 16 * 97 * 97 is what 400x400 inputs
            produce; pass the matching value for other input sizes.
    """

    def __init__(self, fc1_in_features=150544):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(fc1_in_features, 120)
        self.fc2 = nn.Linear(120, 1)
        # Output activation (parameter-free); kept under the name ``fc3``.
        self.fc3 = nn.Sigmoid()

    def forward(self, x):
        # assumes x is (batch, 1, H, W) — conv1 has a single input channel.
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        # No ReLU on the final logit: ReLU followed by sigmoid clamps every
        # output to >= 0.5, so the network could never predict class 0.
        x = self.fc2(x)
        return self.fc3(x)

In [None]:
class Classifier(pl.LightningModule):
    """LightningModule that trains ``Net`` with binary cross-entropy.

    Args:
        lr: Adam learning rate (default 1e-3).
    """

    def __init__(self, lr=1e-3):
        super().__init__()
        self.model = Net()
        self.lr = lr

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=self.lr)

    def _shared_step(self, batch):
        """Return the BCE loss for one ``(x, y)`` batch.

        ``Net`` ends in a single-unit sigmoid, so its output is ``(batch, 1)``;
        the labels are reshaped to ``(batch, 1)`` and cast to float to match.
        """
        x, y = batch
        y_hat = self.model(x)
        y = y.unsqueeze(1).float()
        return F.binary_cross_entropy(y_hat, y)

    def training_step(self, train_batch, batch_idx):
        loss = self._shared_step(train_batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        # Images go into the conv net unflattened, exactly as in training;
        # Conv2d cannot accept a (batch, features) tensor.
        loss = self._shared_step(val_batch)
        self.log('val_loss', loss)
        return loss

In [None]:
from pl_bolts.callbacks import PrintTableMetricsCallback

# Wire the pipeline together: tensor conversion -> data module -> model -> trainer.
transform = transforms.Compose([transforms.ToTensor()])
dm = MyDataModule(data, targets, transform)
clf = Classifier()
trainer = pl.Trainer(
    max_epochs=10,
    callbacks=[PrintTableMetricsCallback()],
)
trainer.fit(clf, datamodule=dm)