In [71]:
import torch
import torch.nn as nn

from panther.nn import SKConv2d


In [72]:
def make_sketched_conv2d(
    layer: nn.Conv2d,
    num_terms: int = 6,
    low_rank: int = 8,
) -> SKConv2d:
    """Build an ``SKConv2d`` whose sketched factors are derived from ``layer``.

    A fresh ``SKConv2d`` is created with the same channel counts, kernel size,
    stride and padding as ``layer``. Its ``S1s`` and ``S2s`` parameters are then
    overwritten with products of the layer's kernel and the new module's
    ``U1s`` / ``U2s`` matrices, and its bias is replaced by a copy of the
    layer's bias. ``layer`` itself is left untouched (its weight and bias are
    cloned).

    Args:
        layer: Source convolution. Must have ``groups == 1``, dilation
            ``(1, 1)`` and a bias.
        num_terms: Number of sketch terms; one ``S1`` / ``S2`` slice is built
            per term.
        low_rank: Rank forwarded to ``SKConv2d``; also used to shape the
            ``U2s`` product.

    Returns:
        The newly created ``SKConv2d``.

    Raises:
        AssertionError: If ``layer`` is not an ``nn.Conv2d``, is grouped, or is
            dilated.
        ValueError: If ``layer`` has no bias.
    """
    assert isinstance(layer, nn.Conv2d), "Layer must be a Conv2d layer"
    assert layer.groups == 1, "Groups must be 1 for SKConv2d"
    assert layer.dilation == (1, 1), "Dilation must be (1, 1) for SKConv2d"
    # Fail fast: the bias-free case is unsupported, so reject it before the
    # sketched module is built and its factors are computed.
    if layer.bias is None:
        raise ValueError(
            "Layer must have a bias parameter, not implemented yet to not have it sketchedconv2d"
        )

    # No bias flag is forwarded here; the bias is copied from `layer` below.
    sketched_conv = SKConv2d(
        in_channels=layer.in_channels,
        out_channels=layer.out_channels,
        kernel_size=layer.kernel_size,
        stride=layer.stride,
        padding=layer.padding,
        num_terms=num_terms,
        low_rank=low_rank,
    )

    def mode4_unfold(tensor: torch.Tensor) -> torch.Tensor:
        """Flatten every leading dimension, keeping the last one as columns."""
        return tensor.reshape(-1, tensor.shape[-1])  # (I1 * I2 * I3, I4)

    # The factors are plain values wrapped in new Parameters, so there is no
    # need to record the computation in the autograd graph.
    with torch.no_grad():
        # Conv2d stores its weight as (out, in, kh, kw); move the output
        # channels last: (in, kh, kw, out).
        kernels = layer.weight.detach().clone().permute(1, 2, 3, 0)

        # presumably U1s[i] is (low_rank, out_channels) — TODO confirm against
        # SKConv2d; the matmul contracts the trailing out-channel axis.
        s1_terms = [
            mode4_unfold(torch.matmul(kernels, sketched_conv.U1s[i].T))
            for i in range(num_terms)
        ]
    sketched_conv.S1s = nn.Parameter(torch.stack(s1_terms))

    with torch.no_grad():
        # Kernel as a matrix with one column per output channel. `kernels` is
        # a permuted (non-contiguous) tensor, so use reshape.
        kernel_matrix = kernels.reshape(
            layer.in_channels
            * sketched_conv.kernel_size[0]
            * sketched_conv.kernel_size[1],
            layer.out_channels,
        )
        s2_terms = [
            mode4_unfold(
                # The intermediate view fixes the expected layout
                # (low_rank, kh, kw, out) and raises if U2s[i] does not match.
                torch.matmul(sketched_conv.U2s[i], kernel_matrix).view(
                    low_rank, *sketched_conv.kernel_size, layer.out_channels
                )
            )
            for i in range(num_terms)
        ]
    sketched_conv.S2s = nn.Parameter(torch.stack(s2_terms))

    sketched_conv.bias = nn.Parameter(layer.bias.detach().clone())

    return sketched_conv

In [73]:
class cnn(nn.Module):
    """Baseline classifier: two 3x3 convolutions followed by a linear head.

    The convolutions map 3 -> 16 -> 32 channels with padding 1, each followed
    by ReLU. The flattened activations (``32 * 32 * 32`` features) feed a
    linear layer with 10 outputs.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this order on purpose: it fixes both the
        # parameter order and the order in which random init values are drawn.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.fc = nn.Linear(in_features=32 * 32 * 32, out_features=10)

    def forward(self, x):
        """Return the 10 output scores per sample for the batch ``x``."""
        hidden = torch.relu(self.conv1(x))
        hidden = torch.relu(self.conv2(hidden))
        # Collapse everything except the batch dimension before the head.
        flat = hidden.reshape(hidden.size(0), -1)
        return self.fc(flat)


class sketched_cnn(nn.Module):
    """Sketched counterpart of ``cnn`` built from ``SKConv2d`` layers.

    Both convolutions are produced by ``make_sketched_conv2d`` from freshly
    initialised ``nn.Conv2d`` layers (3 -> 16 -> 32 channels, 3x3 kernels,
    padding 1), each followed by ReLU, and a linear head with 10 outputs.
    """

    def __init__(self):
        super().__init__()
        self.sketch_conv1 = make_sketched_conv2d(
            nn.Conv2d(3, 16, kernel_size=3, padding=1), num_terms=6, low_rank=8
        )
        self.sketch_conv2 = make_sketched_conv2d(
            nn.Conv2d(16, 32, kernel_size=3, padding=1), num_terms=6, low_rank=8
        )
        # Nothing in forward() downsamples, so for the 32x32 inputs used with
        # this model the flattened size is 32 channels * 32 * 32, the same as
        # the head of `cnn`. The previous 32 * 8 * 8 only fits 8x8 inputs.
        # NOTE(review): assumes SKConv2d keeps the spatial size for a 3x3
        # kernel with padding=1, as nn.Conv2d does — confirm.
        self.fc = nn.Linear(32 * 32 * 32, 10)

    def forward(self, x):
        """Return the 10 output scores per sample for the batch ``x``."""
        x = self.sketch_conv1(x)
        x = torch.relu(x)
        x = self.sketch_conv2(x)
        x = torch.relu(x)
        # Collapse everything except the batch dimension before the head.
        x = x.reshape(x.size(0), -1)
        x = self.fc(x)
        return x


In [74]:
def gen_data_with_labels(
    batch_size=300, num_channels=3, height=32, width=32, num_classes=10
):
    """Generate a random image batch with random integer class labels.

    Args:
        batch_size: Number of samples.
        num_channels: Channels per sample.
        height: Sample height.
        width: Sample width.
        num_classes: Exclusive upper bound for the labels; labels are drawn
            from ``0 .. num_classes - 1``. Defaults to 10.

    Returns:
        A ``(data, labels)`` tuple: ``data`` is a standard-normal tensor of
        shape ``(batch_size, num_channels, height, width)`` and ``labels`` is
        an integer tensor of shape ``(batch_size,)``.
    """
    # Draw the data before the labels so results under a fixed seed stay the
    # same as before the `num_classes` parameter was introduced.
    data = torch.randn(batch_size, num_channels, height, width)
    labels = torch.randint(0, num_classes, (batch_size,))
    return data, labels


def train(model, data, targets, epochs=3, lr=0.001):
    """Fit ``model`` with full-batch Adam steps, printing the loss per epoch.

    Each epoch is one optimisation step over the entire ``data`` batch using
    cross-entropy against ``targets``. The model's parameters are updated in
    place and nothing is returned.
    """
    loss_fn = nn.CrossEntropyLoss()
    adam = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch_number in range(1, epochs + 1):
        adam.zero_grad()
        loss = loss_fn(model(data), targets)
        loss.backward()
        adam.step()
        print(f"Epoch [{epoch_number}/{epochs}], Loss: {loss.item():.4f}")


def test(model, data, targets):
    """Tests the model."""
    with torch.no_grad():
        outputs = model(data)
        _, predicted = torch.max(outputs.data, 1)
        total = targets.size(0)
        correct = (predicted == targets).sum().item()
        print(f"Accuracy: {100 * correct / total:.2f}%")

In [75]:
# Baseline vs. sketched comparison.
# NOTE(review): no random seed is set in the cells shown here, so the printed
# loss/accuracy values are expected to vary from run to run.
model = cnn()
model2 = sketched_cnn()
data, targets = gen_data_with_labels()
# Train the baseline for a single epoch, then score it on the very same batch
# it was trained on (there is no held-out split in this cell).
train(model, data, targets, epochs=1, lr=0.001)
test(model, data, targets)
# Rebuild the sketched model from the trained baseline: reuse the baseline's
# fc module object (shared, not copied) and replace both sketched convolutions
# with ones derived from the trained conv1 / conv2 layers.
model2.fc = model.fc
model2.sketch_conv1 = make_sketched_conv2d(model.conv1, num_terms=6, low_rank=8)
model2.sketch_conv2 = make_sketched_conv2d(model.conv2, num_terms=6, low_rank=8)
# Score the sketched model on the same batch for a side-by-side comparison.
test(model2, data, targets)

Epoch [1/1], Loss: 2.3012
Accuracy: 29.33%
Accuracy: 27.33%
