In [None]:
# Clone the torchsde repository (examples, benchmarks, tests). The library itself
# is installed from PyPI in a later cell. If `torchsde/` already exists, git will
# refuse to clone into it, so re-running this cell only prints an error.
!git clone https://github.com/google-research/torchsde.git

Cloning into 'torchsde'...
remote: Enumerating objects: 1620, done.[K
remote: Counting objects: 100% (190/190), done.[K
remote: Compressing objects: 100% (125/125), done.[K
remote: Total 1620 (delta 103), reused 115 (delta 56), pack-reused 1430 (from 1)[K
Receiving objects: 100% (1620/1620), 4.31 MiB | 24.54 MiB/s, done.
Resolving deltas: 100% (1107/1107), done.


In [None]:
import os

# Enter the clone made by the cell above. The repo root itself contains an inner
# `torchsde/` package directory, so a blind `chdir('./torchsde')` on a re-run would
# descend into that package. Only chdir when `./torchsde` is the clone (it holds
# setup.py); if we are already at the repo root, stay put.
if os.path.isfile(os.path.join('torchsde', 'setup.py')):
    os.chdir('torchsde')
elif not os.path.isfile('setup.py'):
    raise FileNotFoundError("torchsde clone not found; run the `git clone` cell first")
print(os.getcwd())

/content/torchsde


In [None]:
# List the current directory to confirm we are at the root of the torchsde clone.
!ls

assets	    CONTRIBUTING.md  DOCUMENTATION.md  LICENSE	       README.md  tests
benchmarks  diagnostics      examples	       pyproject.toml  setup.py   torchsde


In [None]:
# Install torchsde from PyPI. The version is pinned to the one this notebook was
# run with (0.2.6), so re-runs get the same solver code. `%pip` (unlike `!pip`)
# installs into the running kernel's environment.
%pip install -q torchsde==0.2.6

Collecting torchsde
  Downloading torchsde-0.2.6-py3-none-any.whl.metadata (5.3 kB)
Collecting trampoline>=0.1.2 (from torchsde)
  Downloading trampoline-0.1.2-py3-none-any.whl.metadata (10 kB)
Downloading torchsde-0.2.6-py3-none-any.whl (61 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.2/61.2 kB[0m [31m5.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading trampoline-0.1.2-py3-none-any.whl (5.2 kB)
Installing collected packages: trampoline, torchsde
Successfully installed torchsde-0.2.6 trampoline-0.1.2
Collecting fire
  Downloading fire-0.7.0.tar.gz (87 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m87.2/87.2 kB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: fire
  Building wheel for fire (setup.py) ... [?25l[?25hdone
  Created wheel for fire: filename=fire-0.7.0-py3-none-any.whl size=114249 sha256=cbf32ca353e81f946a3140271fe4a1dc4c332b0321cabee4

In [None]:
########### MNIST Dataset Training Experiment ##############

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import random_split
from torchvision import datasets, transforms
import torchsde
from tqdm import tqdm
import numpy as np

class MNISTSDEModel(torchsde.SDEIto):
    """Itô SDE whose state is a flattened MNIST image.

    The drift ``f`` is a 3-layer ReLU MLP of the state. The diffusion ``g`` is a
    learned per-dimension scale ``sigma``, independent of state and time, and
    returned with the same shape as the state (one noise channel per
    dimension).

    Args:
        data_dim: Size of the flattened state vector.
        hidden_dim: Width of the drift MLP's two hidden layers.
        noise_type: torchsde noise type passed to the base class. ``g``'s output
            shape matches the ``'diagonal'`` default.
    """

    def __init__(self, data_dim, hidden_dim=64, noise_type='diagonal'):
        super().__init__(noise_type=noise_type)

        self.fc1 = nn.Linear(data_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, data_dim)
        # One learnable diffusion scale per state dimension, initialised to 1.
        self.sigma = nn.Parameter(torch.ones(data_dim))

    def f(self, t, x):
        """Drift: MLP of the state; ``t`` is ignored (autonomous SDE)."""
        h = torch.relu(self.fc1(x))
        h = torch.relu(self.fc2(h))
        return self.fc3(h)

    def g(self, t, x):
        """Diffusion: ``sigma`` broadcast to the shape of ``x``; ``t`` is ignored."""
        return self.sigma * torch.ones_like(x)

    def h(self, t, x):
        """Auxiliary function; identically zero."""
        return torch.zeros_like(x)

    def sample(self, batch_size, x0, t, dt):
        """Integrate the SDE from ``x0`` over ``[0, t]`` with fixed step ``dt``.

        Args:
            batch_size: Unused; kept so existing callers keep working.
            x0: Initial state. Presumably ``(batch, data_dim)`` — TODO confirm
                for callers other than ``train_mnist_sde``.
            t: Integration horizon. A horizon shorter than ``dt`` (including 0)
                takes no steps and returns ``x0`` unchanged.
            dt: Fixed solver step size; must be positive.

        Returns:
            Tensor of shape ``(num_steps + 1, *x0.shape)`` with the state at times
            ``0, dt, ..., num_steps * dt``. Index ``-1`` is the final state.

        Raises:
            ValueError: if ``dt`` is not positive.
        """
        if dt <= 0:
            raise ValueError(f"dt must be positive, got {dt}")
        # The epsilon guards against float error, e.g. 0.3 / 0.1 == 2.999...
        # would otherwise truncate to 2 steps.
        num_steps = max(int(t / dt + 1e-9), 0)
        if num_steps == 0:
            return x0.unsqueeze(0)
        # A single solver call over the whole grid returns the state at every grid
        # time. Because f and g ignore t, this matches restarting from time 0 for
        # each step, without the per-call overhead. ts follows x0's dtype/device.
        ts = torch.linspace(0.0, num_steps * dt, num_steps + 1,
                            dtype=x0.dtype, device=x0.device)
        return torchsde.sdeint(self, x0, ts, dt=dt)

def _evaluate_mnist_sde(model, classifier, loader, t_end, dt):
    """Return the accuracy (%) of ``classifier`` on SDE states at time ``t_end``.

    Both modules are switched to eval mode. Returns 0.0 for an empty loader.
    """
    model.eval()
    classifier.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        test_pbar = tqdm(loader, desc='Testing', leave=False)
        for data, target in test_pbar:
            data = data.view(data.size(0), -1)
            # Evaluate on the same horizon used in training. With t=0 the SDE takes
            # no steps and the classifier would be scored on raw pixels instead.
            xs = model.sample(data.size(0), x0=data, t=t_end, dt=dt)
            predicted = classifier(xs[-1]).argmax(dim=1)

            total += target.size(0)
            correct += (predicted == target).sum().item()

            test_pbar.set_postfix({'accuracy': f'{100 * correct / total:.2f}%'})

    return 100 * correct / total if total else 0.0


def train_mnist_sde(epochs=10, batch_size=128, learning_rate=0.001, validation_split=0.5,
                    t_end=1.0, dt=0.1):
    """Jointly train an SDE feature map and an MLP classifier on MNIST.

    Each image is flattened to a 784-vector, evolved by ``MNISTSDEModel`` from
    time 0 to ``t_end``, and the final state is classified. The MNIST *training*
    set is split at random. A ``validation_split`` fraction is held out and
    used for the per-epoch accuracy reported as "Test Accuracy".

    Args:
        epochs: Number of passes over the training part.
        batch_size: Mini-batch size for both loaders.
        learning_rate: Adam learning rate shared by SDE and classifier parameters.
        validation_split: Fraction of the MNIST training set held out; must be
            strictly between 0 and 1 and leave both parts non-empty.
        t_end: SDE integration horizon, used for both training and evaluation.
        dt: Fixed SDE solver step.

    Returns:
        ``(model, classifier, train_losses, test_accuracies)``. The two lists hold
        one entry per epoch: the mean training loss and the held-out accuracy in %.

    Raises:
        ValueError: if ``validation_split`` yields an empty train or held-out part.
    """
    if not 0.0 < validation_split < 1.0:
        raise ValueError(f"validation_split must be in (0, 1), got {validation_split}")

    # Normalize with the standard MNIST mean/std.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    full_dataset = datasets.MNIST(
        root='./data',
        train=True,
        download=True,
        transform=transform
    )

    # Random split of the MNIST training set into a training and a held-out part.
    train_size = int(len(full_dataset) * (1 - validation_split))
    test_size = len(full_dataset) - train_size
    if train_size == 0 or test_size == 0:
        raise ValueError(f"validation_split={validation_split} leaves an empty split")
    train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    input_dim = 28 * 28
    model = MNISTSDEModel(data_dim=input_dim)

    classifier = nn.Sequential(
        nn.Linear(input_dim, 128),
        nn.ReLU(),
        nn.Linear(128, 10)
    )

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(list(model.parameters()) + list(classifier.parameters()), lr=learning_rate)

    # The integration window is the same for every batch, so build it once.
    # Evaluation reuses t1 so it sees the same horizon as training.
    t0, t1 = 0.0, float(t_end)
    ts = torch.tensor([t0, t1])

    train_losses = []
    test_accuracies = []

    for epoch in range(epochs):
        model.train()
        classifier.train()
        total_train_loss = 0.0

        train_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}',
                          postfix={'train_loss': 0.0, 'test_accuracy': 0.0})

        for data, target in train_pbar:
            # Flatten each image to a vector of length input_dim.
            data = data.view(data.size(0), -1)
            xs = torchsde.sdeint(model, data, ts, dt=dt)

            outputs = classifier(xs[-1])
            loss = criterion(outputs, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()
            train_pbar.set_postfix({'train_loss': f'{loss.item():.4f}'})

        avg_train_loss = total_train_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        test_accuracy = _evaluate_mnist_sde(model, classifier, test_loader, t1, dt)
        test_accuracies.append(test_accuracy)

        train_pbar.close()

        print(f'Epoch {epoch+1}:')
        print(f'  Train Loss: {avg_train_loss:.4f}')
        print(f'  Test Accuracy: {test_accuracy:.2f}%')

    # max() of an empty list raises, which happens when epochs == 0.
    if test_accuracies:
        print(f"Best Test Accuracy: {max(test_accuracies):.2f}%")

    return model, classifier, train_losses, test_accuracies

if __name__ == '__main__':
    # Seed the RNGs before training. torch drives weight init, random_split and
    # DataLoader shuffling. numpy is seeded as well; presumably torchsde's
    # Brownian sampling draws from it — TODO confirm.
    SEED = 42
    for seed_rng in (torch.manual_seed, np.random.seed):
        seed_rng(SEED)

    model, classifier, train_losses, test_accuracies = train_mnist_sde(epochs=10)

Epoch 1/10: 100%|██████████| 235/235 [01:40<00:00,  2.33it/s, train_loss=0.3029]


Epoch 1:
  Train Loss: 0.4158
  Test Accuracy: 79.71%


Epoch 2/10: 100%|██████████| 235/235 [01:41<00:00,  2.32it/s, train_loss=0.1341]


Epoch 2:
  Train Loss: 0.1839
  Test Accuracy: 73.89%


Epoch 3/10: 100%|██████████| 235/235 [01:42<00:00,  2.29it/s, train_loss=0.1850]


Epoch 3:
  Train Loss: 0.1302
  Test Accuracy: 69.64%


Epoch 4/10: 100%|██████████| 235/235 [01:41<00:00,  2.32it/s, train_loss=0.0815]


Epoch 4:
  Train Loss: 0.1089
  Test Accuracy: 69.68%


Epoch 5/10: 100%|██████████| 235/235 [01:44<00:00,  2.26it/s, train_loss=0.0733]


Epoch 5:
  Train Loss: 0.0896
  Test Accuracy: 65.51%


Epoch 6/10: 100%|██████████| 235/235 [01:38<00:00,  2.40it/s, train_loss=0.0703]


Epoch 6:
  Train Loss: 0.0719
  Test Accuracy: 61.43%


Epoch 7/10: 100%|██████████| 235/235 [01:41<00:00,  2.32it/s, train_loss=0.0497]


Epoch 7:
  Train Loss: 0.0617
  Test Accuracy: 63.94%


Epoch 8/10: 100%|██████████| 235/235 [01:42<00:00,  2.29it/s, train_loss=0.0123]


Epoch 8:
  Train Loss: 0.0537
  Test Accuracy: 69.99%


Epoch 9/10: 100%|██████████| 235/235 [01:43<00:00,  2.28it/s, train_loss=0.0542]


Epoch 9:
  Train Loss: 0.0498
  Test Accuracy: 64.37%


Epoch 10/10: 100%|██████████| 235/235 [01:41<00:00,  2.31it/s, train_loss=0.0847]


Epoch 10:
  Train Loss: 0.0433
  Test Accuracy: 63.69%

Training Complete!
Best Test Accuracy: 79.71%
