In [2]:
!git clone https://github.com/kymatio/kymatio.git

Cloning into 'kymatio'...
remote: Enumerating objects: 6471, done.[K
remote: Counting objects: 100% (259/259), done.[K
remote: Compressing objects: 100% (158/158), done.[K
remote: Total 6471 (delta 126), reused 178 (delta 91), pack-reused 6212 (from 1)[K
Receiving objects: 100% (6471/6471), 2.59 MiB | 29.50 MiB/s, done.
Resolving deltas: 100% (4281/4281), done.


In [3]:
pip install torch torchvision kymatio numpy

Collecting kymatio
  Downloading kymatio-0.3.0-py3-none-any.whl.metadata (9.6 kB)
Collecting appdirs (from kymatio)
  Downloading appdirs-1.4.4-py2.py3-none-any.whl.metadata (9.0 kB)
Collecting configparser (from kymatio)
  Downloading configparser-7.1.0-py3-none-any.whl.metadata (5.4 kB)
Downloading kymatio-0.3.0-py3-none-any.whl (87 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m87.6/87.6 kB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading appdirs-1.4.4-py2.py3-none-any.whl (9.6 kB)
Downloading configparser-7.1.0-py3-none-any.whl (17 kB)
Installing collected packages: appdirs, configparser, kymatio
Successfully installed appdirs-1.4.4 configparser-7.1.0 kymatio-0.3.0


In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from kymatio.torch import Scattering2D
from kymatio.scattering2d.core.scattering2d import scattering2d
import matplotlib.pyplot as plt


In [5]:
# Function to track weight changes
@torch.no_grad()
def track_weight_changes(initial_params, current_params):
    changes = 0
    for init, curr in zip(initial_params, current_params):
        changes += (init - curr).abs().mean().item()
    return changes

# Training function
def train(model, train_loader, optimizer, criterion, device):
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

# Testing function
def test(model, test_loader, criterion, device):
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    return test_loss, accuracy


In [6]:
# Class 1: Fixed Scattering Classifier
class FixedScatteringClassifier(nn.Module):
    def __init__(self, J, shape, L, output_size=10):
        super(FixedScatteringClassifier, self).__init__()
        self.scattering = Scattering2D(J=J, shape=shape)

        dummy_input = torch.randn(1, 1, *shape)
        scattering_output = self.scattering(dummy_input)

        self.linear_in_size = scattering_output.numel()
        self.linear = nn.Linear(self.linear_in_size, output_size)

    def forward(self, x):
        x = self.scattering(x)
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x


def main_step1():
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST('./data', train=False, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    J = 2
    shape = (28, 28)
    L = 8

    # Initialize and train the Fixed Scattering Classifier
    model_fixed = FixedScatteringClassifier(J, shape, L).to(device)
    optimizer_fixed = optim.Adam(model_fixed.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    print("Training with fixed scattering transform...")
    for epoch in range(3):
        train(model_fixed, train_loader, optimizer_fixed, criterion, device)
        test_loss, accuracy = test(model_fixed, test_loader, criterion, device)
        print(f'Epoch {epoch+1}: Test loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%')

    # Save the trained model for further steps
    torch.save(model_fixed.state_dict(), 'fixed_classifier.pth')
    print("Step 1 completed: Fixed classifier training finished and saved.")

main_step1()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 100MB/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 18.9MB/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz





Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 86.3MB/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 3.05MB/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Training with fixed scattering transform...
Epoch 1: Test loss: 0.0002, Accuracy: 96.85%
Epoch 2: Test loss: 0.0001, Accuracy: 97.68%
Epoch 3: Test loss: 0.0001, Accuracy: 98.01%
Step 1 completed: Fixed classifier training finished and saved.


## Trainable Scattering Model

In [7]:
class ScatteringTorch2DTrainable(nn.Module):
    def __init__(self, J, shape, L=8, output_size=10, pretrained_classifier=None):
        super(ScatteringTorch2DTrainable, self).__init__()
        self.scattering = Scattering2D(J=J, shape=shape)
        self.J = J
        self.L = L
        self.pad = self.scattering.pad
        self.unpad = self.scattering.unpad
        self.backend = self.scattering.backend
        self.max_order = self.scattering.max_order
        self.out_type = self.scattering.out_type

        self.phi = self.scattering.phi.copy()
        self.psi = [p.copy() for p in self.scattering.psi]

        # Convert phi and psi levels to trainable parameters
        self.params = nn.ParameterList([nn.Parameter(torch.tensor(level).unsqueeze(-1), requires_grad=True)
                                            for level in self.phi['levels']])
        psi_params = nn.ParameterList([nn.Parameter(torch.tensor(psi_level).unsqueeze(-1), requires_grad=True)
                                            for psi in self.psi
                                            for psi_level in psi['levels']])

        self.params.extend(psi_params)

        # Load the pre-trained classifier and freeze its parameters
        if pretrained_classifier is not None:
            self.linear = pretrained_classifier.linear  # Use the pre-trained classifier's linear layer
        else:
            dummy_input = torch.randn(1, 1, *shape)
            scattering_output = self.scattering(dummy_input)
            scattering_output_size = scattering_output.numel()
            self.linear = nn.Linear(scattering_output_size, output_size)

        # Freeze the classifier parameters
        for param in self.linear.parameters():
            param.requires_grad = False  # Freeze the pre-trained classifier

    def load_filters(self):
        """ This function loads filters from the module's parameters """
        # each time scattering is run, one needs to make sure self.psi and self.phi point to
        # the correct buffers
        param_dict = dict(self.named_parameters())

        n = 0

        # Load phi levels
        phis = {k: v for k, v in self.phi.items() if k != 'levels'}
        phis['levels'] = []
        for phi_level in self.phi['levels']:
            phis['levels'].append(param_dict['params.' + str(n)])
            n += 1

        # Load psi levels
        psis = [{} for _ in range(len(self.psi))]
        for j in range(len(self.psi)):
            psis[j] = {k: v for k, v in self.psi[j].items() if k != 'levels'}
            psis[j]['levels'] = []
            for psi_level in self.psi[j]['levels']:
                psis[j]['levels'].append(
                    param_dict['params.' + str(n)]
                )
                n += 1

        return phis, psis

    def forward(self, x):
        x = x.squeeze()
        batch_shape = x.shape[:-2]
        signal_shape = x.shape[-2:]

        x = x.reshape((-1,) + signal_shape)
        phis, psis = self.load_filters()
        # Apply scattering transform using loaded filters
        S = scattering2d(x, self.pad, self.unpad, self.backend, self.J, self.L,
                         phis, psis, self.max_order, 'array')

        # Pass through the frozen classifier
        S = S.view(S.size(0), -1)  # Flatten the scattering output
        output = self.linear(S)  # Classify using the frozen classifier

        return output


In [8]:
def main_step2():
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST('./data', train=False, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    J = 2
    shape = (28, 28)
    L = 8

    # Load the Fixed Scattering Classifier
    model_fixed = FixedScatteringClassifier(J=2, shape=(28,28), L=8)
    model_fixed.load_state_dict(torch.load('fixed_classifier.pth'))
    model_fixed.eval()

    # Initialize and train the Scattering Classifier
    model_1 = ScatteringTorch2DTrainable(J=2, shape = (28,28), L=8, pretrained_classifier=model_fixed).to(device)

    first_model = ScatteringTorch2DTrainable(J=2, shape = (28,28), L=8, pretrained_classifier=model_fixed).to(device)


    optimizer = optim.Adam(filter(lambda x: x.requires_grad, model_1.parameters()), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    changes = []


    print("Training scattering transform...")
    for epoch in range(3):
        train(model_1, train_loader, optimizer, criterion, device)
        test_loss, accuracy = test(model_1, test_loader, criterion, device)
        change = track_weight_changes(first_model.parameters(), model_1.parameters())
        print(f'Epoch {epoch+1}: Test loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%, Change: {change:.2f}')
        changes.append(change)

    # Save the trained model for further steps
    torch.save(model_1.state_dict(), 'trainable_wavelets_classifier.pth')
    print("Step 2 completed: Trainable wavelets classifier training finished and saved.")

main_step2()

  model_fixed.load_state_dict(torch.load('fixed_classifier.pth'))


Training scattering transform...
Epoch 1: Test loss: 0.0000, Accuracy: 98.48%, Change: 0.60
Epoch 2: Test loss: 0.0000, Accuracy: 98.64%, Change: 1.00
Epoch 3: Test loss: 0.0000, Accuracy: 98.71%, Change: 1.40
Step 2 completed: Trainable wavelets classifier training finished and saved.


## Trainable Random Scattering Model

In [10]:
class RandomScatteringTorch2DTrainable(nn.Module):
    def __init__(self, J, shape, L=8, output_size=10, pretrained_classifier=None):
        super(RandomScatteringTorch2DTrainable, self).__init__()
        self.scattering = Scattering2D(J=J, shape=shape)
        self.J = J
        self.L = L
        self.pad = self.scattering.pad
        self.unpad = self.scattering.unpad
        self.backend = self.scattering.backend
        self.max_order = self.scattering.max_order
        self.out_type = self.scattering.out_type

        self.phi = self.scattering.phi.copy()
        self.psi = [p.copy() for p in self.scattering.psi]

        # Convert phi and psi levels to trainable parameters
        self.params = nn.ParameterList([
            nn.Parameter(torch.randn(*level.shape).unsqueeze(-1), requires_grad=True)
                                            for level in self.phi['levels']
        ])
        psi_params = nn.ParameterList([nn.Parameter(torch.randn(*psi_level.shape).unsqueeze(-1), requires_grad=True)
                                            for psi in self.psi
                                            for psi_level in psi['levels']])

        self.params.extend(psi_params)

        # Load the pre-trained classifier and freeze its parameters
        if pretrained_classifier is not None:
            self.linear = pretrained_classifier.linear  # Use the pre-trained classifier's linear layer
        else:
            dummy_input = torch.randn(1, 1, *shape)
            scattering_output = self.scattering(dummy_input)
            scattering_output_size = scattering_output.numel()
            self.linear = nn.Linear(scattering_output_size, output_size)

        # Freeze the classifier parameters
        for param in self.linear.parameters():
            param.requires_grad = False  # Freeze the pre-trained classifier

    def load_filters(self):
        """ This function loads filters from the module's parameters """
        # each time scattering is run, one needs to make sure self.psi and self.phi point to
        # the correct buffers
        param_dict = dict(self.named_parameters())

        n = 0

        # Load phi levels
        phis = {k: v for k, v in self.phi.items() if k != 'levels'}
        phis['levels'] = []
        for phi_level in self.phi['levels']:
            phis['levels'].append(param_dict['params.' + str(n)])
            n += 1

        # Load psi levels
        psis = [{} for _ in range(len(self.psi))]
        for j in range(len(self.psi)):
            psis[j] = {k: v for k, v in self.psi[j].items() if k != 'levels'}
            psis[j]['levels'] = []
            for psi_level in self.psi[j]['levels']:
                psis[j]['levels'].append(
                    param_dict['params.' + str(n)]
                )
                n += 1

        return phis, psis

    def forward(self, x):
        x = x.squeeze()
        batch_shape = x.shape[:-2]
        signal_shape = x.shape[-2:]

        x = x.reshape((-1,) + signal_shape)
        phis, psis = self.load_filters()
        # Apply scattering transform using loaded filters
        S = scattering2d(x, self.pad, self.unpad, self.backend, self.J, self.L,
                         phis, psis, self.max_order, 'array')

        # Pass through the frozen classifier
        S = S.view(S.size(0), -1)  # Flatten the scattering output
        output = self.linear(S)  # Classify using the frozen classifier

        return output


In [26]:
def main_step3():
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST('./data', train=False, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    J = 2
    shape = (28, 28)
    L = 8

    # Load the Fixed Scattering Classifier
    model_fixed = FixedScatteringClassifier(J=2, shape=(28,28), L=8)
    model_fixed.load_state_dict(torch.load('fixed_classifier.pth'))
    model_fixed.eval()

    model_filters_train = ScatteringTorch2DTrainable(J=2, shape=(28,28), L=8)
    model_filters_train.load_state_dict(torch.load('trainable_wavelets_classifier.pth'))
    model_filters_train.eval()

    # Initialize and train the Randm Scattering Classifier
    model_1 = RandomScatteringTorch2DTrainable(J=2, shape = (28,28), L=8, pretrained_classifier=model_fixed).to(device)

    first_model = RandomScatteringTorch2DTrainable(J=2, shape = (28,28), L=8, pretrained_classifier=model_fixed).to(device)

    second_model = RandomScatteringTorch2DTrainable(J=2, shape = (28,28), L=8, pretrained_classifier=model_filters_train).to(device)


    optimizer = optim.Adam(filter(lambda x: x.requires_grad, model_1.parameters()), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    changes = []
    changes2 = []


    print("Training random scattering transform...")
    for epoch in range(3):
        train(model_1, train_loader, optimizer, criterion, device)
        test_loss, accuracy = test(model_1, test_loader, criterion, device)
        change = track_weight_changes(first_model.parameters(), model_1.parameters())
        change2 = track_weight_changes(second_model.parameters(), model_1.parameters())
        print(f'Epoch {epoch+1}: Test loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%, Change: {change:.2f}')
        print(f'Epoch {epoch+1}: Test loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%, Change: {change2:.2f}')
        changes.append(change)
        changes2.append(change2)

    # Save the trained model for further steps
    torch.save(model_1.state_dict(), 'random_wavelets_classifier.pth')
    print("Step 3 completed: random wavelets classifier training finished and saved.")

main_step3()

  model_fixed.load_state_dict(torch.load('fixed_classifier.pth'))
  model_filters_train.load_state_dict(torch.load('trainable_wavelets_classifier.pth'))


Training random scattering transform...
Epoch 1: Test loss: 0.0006, Accuracy: 82.19%, Change: 20.20
Epoch 1: Test loss: 0.0006, Accuracy: 82.19%, Change: 19.79
Epoch 2: Test loss: 0.0003, Accuracy: 91.20%, Change: 20.13
Epoch 2: Test loss: 0.0003, Accuracy: 91.20%, Change: 19.73
Epoch 3: Test loss: 0.0002, Accuracy: 93.97%, Change: 20.08
Epoch 3: Test loss: 0.0002, Accuracy: 93.97%, Change: 19.70
Step 3 completed: random wavelets classifier training finished and saved.
