In [21]:
import torch
import torch.nn as nn

from panther.nn import SKConv2d


In [22]:
class cnn(nn.Module):
    """Baseline convolutional classifier built from standard torch layers.

    Architecture: two 3x3 convolutions with padding 1 (3 -> 16 -> 32
    channels), each followed by a ReLU, then one linear layer mapping
    32 * 32 * 32 flattened features to 10 outputs. No pooling is applied, so
    the linear layer's input size implies 32x32 spatial inputs.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as they are applied.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.fc = nn.Linear(in_features=32 * 32 * 32, out_features=10)

    def forward(self, x):
        """Return the 10 raw class scores for each sample in the batch ``x``."""
        hidden = torch.relu(self.conv1(x))
        hidden = torch.relu(self.conv2(hidden))
        # Collapse everything except the batch dimension before the linear head.
        flat = hidden.reshape(hidden.size(0), -1)
        return self.fc(flat)


class sketched_cnn(nn.Module):
    """Sketched counterpart of the baseline ``cnn`` classifier.

    Each convolution is an ``SKConv2d`` built with ``SKConv2d.fromTorch`` from
    a freshly initialised ``nn.Conv2d`` (3 -> 16 and 16 -> 32 channels, 3x3
    kernel, padding 1), using ``num_terms=6`` and ``low_rank=8``. Each is
    followed by a ReLU; a single linear layer produces 10 outputs.
    """

    def __init__(self):
        super(sketched_cnn, self).__init__()
        self.sketch_conv1 = SKConv2d.fromTorch(
            nn.Conv2d(3, 16, kernel_size=3, padding=1), num_terms=6, low_rank=8
        )
        self.sketch_conv2 = SKConv2d.fromTorch(
            nn.Conv2d(16, 32, kernel_size=3, padding=1), num_terms=6, low_rank=8
        )
        # The wrapped convolutions use a 3x3 kernel with padding 1 and there is
        # no pooling, so a 32x32 input stays 32x32 and flattens to 32 * 32 * 32
        # features, the same head size the baseline ``cnn`` uses. The previous
        # 32 * 8 * 8 assumed a downsampling step that does not exist.
        # NOTE(review): assumes SKConv2d keeps the wrapped layer's output shape.
        self.fc = nn.Linear(32 * 32 * 32, 10)

    def forward(self, x):
        """Return the 10 raw class scores for each sample in the batch ``x``."""
        x = self.sketch_conv1(x)
        x = torch.relu(x)
        x = self.sketch_conv2(x)
        x = torch.relu(x)
        # Flatten everything except the batch dimension for the linear head.
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        return x


In [23]:
def gen_data_with_labels(
    batch_size=300, num_channels=3, height=32, width=32, num_classes=10
):
    """Generate a random batch of inputs and integer class labels for testing.

    Args:
        batch_size: Number of samples in the batch.
        num_channels: Channels per sample.
        height: Spatial height of each sample.
        width: Spatial width of each sample.
        num_classes: Exclusive upper bound for the labels. Defaults to 10,
            the range that was previously hard-coded.

    Returns:
        A ``(data, labels)`` tuple. ``data`` comes from ``torch.randn`` with
        shape ``(batch_size, num_channels, height, width)``; ``labels`` comes
        from ``torch.randint`` with shape ``(batch_size,)`` and values in
        ``[0, num_classes)``.
    """
    # Draw the inputs before the labels so a seeded call consumes the global
    # RNG in the same order as before.
    data = torch.randn(batch_size, num_channels, height, width)
    labels = torch.randint(0, num_classes, (batch_size,))
    return data, labels


def train(model, data, targets, epochs=3, lr=0.001):
    """Fit ``model`` on a single batch with Adam and cross-entropy loss.

    Every epoch is one full-batch optimisation step over ``data`` and
    ``targets``. The loss computed before that epoch's parameter update is
    printed with four decimals.

    Args:
        model: Module whose ``parameters()`` are optimised.
        data: Input batch passed to ``model``.
        targets: Class-index targets passed to ``nn.CrossEntropyLoss``.
        epochs: Number of optimisation steps.
        lr: Adam learning rate.
    """
    loss_fn = nn.CrossEntropyLoss()
    adam = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch_no in range(1, epochs + 1):
        adam.zero_grad()
        batch_loss = loss_fn(model(data), targets)
        batch_loss.backward()
        adam.step()
        print(f"Epoch [{epoch_no}/{epochs}], Loss: {batch_loss.item():.4f}")


def test(model, data, targets):
    """Tests the model."""
    with torch.no_grad():
        outputs = model(data)
        _, predicted = torch.max(outputs.data, 1)
        total = targets.size(0)
        correct = (predicted == targets).sum().item()
        print(f"Accuracy: {100 * correct / total:.2f}%")

In [24]:
# Seed torch's global RNG so weight initialisation and the random batch are
# reproducible across Restart & Run All.
# NOTE(review): assumes SKConv2d also draws its sketches from torch's global
# RNG — confirm.
torch.manual_seed(0)

model = cnn()
model2 = sketched_cnn()
data, targets = gen_data_with_labels()

# Take one optimisation step on the baseline and score it on the same batch.
# Inputs and labels are random, so accuracy is only a relative signal here.
train(model, data, targets, epochs=1, lr=0.001)
test(model, data, targets)

# Rebuild the sketched model from the trained baseline: share its linear head
# and sketch its trained convolutions, so the second accuracy reflects the
# effect of sketching the same layers.
model2.fc = model.fc
model2.sketch_conv1 = SKConv2d.fromTorch(model.conv1, num_terms=6, low_rank=8)
model2.sketch_conv2 = SKConv2d.fromTorch(model.conv2, num_terms=6, low_rank=8)
test(model2, data, targets)

Epoch [1/1], Loss: 2.2982
Accuracy: 22.00%
Accuracy: 13.33%
