In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models


In [11]:
class RLRRLinear(nn.Module):
    """Linear layer with a frozen base weight and a trainable rank-1 rescaling.

    ``forward`` applies ``F.linear`` with the effective parameters::

        W_eff = (1 + s_left[:, None] * s_right[None, :]) * weight
        b_eff = bias + f            # only when the layer has a bias

    ``weight`` and ``bias`` are registered as *buffers*: they are part of the
    state dict but are not returned by ``parameters()`` and receive no
    gradient.  They are allocated with ``torch.empty`` (uninitialised memory),
    so the caller must fill them, e.g. ``layer.weight.copy_(pretrained_w)``,
    before the first forward pass.

    The only trainable parameters are ``s_left`` (out_feature,), ``s_right``
    (in_feature,) and ``f`` (out_feature,).

    Args:
        in_feature: size of each input sample.
        out_feature: size of each output sample.
        bias: if False, the layer has no bias and ``f`` is not used in
            ``forward``.
        init_std: standard deviation of the normal initialisation of
            ``s_left``.  Must be non-zero for the rescaling to be trainable.
    """

    def __init__(self,in_feature,out_feature,bias=True,init_std=0.02):
        super().__init__()
        self.register_buffer('weight',torch.empty(out_feature,in_feature))
        if bias:
            self.register_buffer('bias',torch.empty(out_feature))
        else:
            self.register_buffer('bias',None)

        self.s_left = nn.Parameter(torch.empty(out_feature))
        self.s_right = nn.Parameter(torch.zeros(in_feature))
        self.f = nn.Parameter(torch.zeros(out_feature))

        # The gradient of s_left is proportional to s_right and vice versa, so
        # starting BOTH at zero leaves both with an exactly-zero gradient
        # forever.  Starting only s_right at zero keeps the initial scaling at
        # exactly 1 (the layer equals the frozen base layer) while s_right
        # gets a non-zero gradient on the first step.
        nn.init.normal_(self.s_left, mean=0.0, std=init_std)

    def forward(self,x):
        """Apply the rescaled linear map to ``x`` (last dim = in_feature)."""
        # Outer product -> (out_feature, in_feature) element-wise scale.
        scaling = 1 + self.s_left.unsqueeze(1) * self.s_right.unsqueeze(0)
        w_re = scaling * self.weight
        if self.bias is not None:
            b_re = self.bias + self.f
        else:
            b_re = None
        return F.linear(x,w_re,b_re)

In [12]:
class ResNet(nn.Module):
    """ImageNet-pretrained ResNet-50 with an ``RLRRLinear`` classifier head.

    Every parameter outside the head (``fc.*``) has ``requires_grad`` set to
    False, so only the head's ``s_left`` / ``s_right`` / ``f`` are trainable.

    Args:
        num_classes: number of output logits.  When it matches the pretrained
            head (1000 for ImageNet) the pretrained ``fc`` weight and bias
            become the head's frozen base.  Otherwise the frozen base is
            taken from a freshly initialised ``nn.Linear`` of the right shape.
            NOTE(review): in that case the base weights are random *and*
            frozen; confirm this is the intended setup for a new label set.
    """

    def __init__(self,num_classes = 1000):
        super().__init__()
        # ``pretrained=True`` is deprecated in newer torchvision in favour of
        # ``weights=``; IMAGENET1K_V1 is the checkpoint ``pretrained=True``
        # loaded.  Fall back to the legacy keyword on older releases.
        if hasattr(models, "ResNet50_Weights"):
            backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        else:
            backbone = models.resnet50(pretrained=True)
        self.resnet = backbone

        source_fc = self.resnet.fc
        in_feature = source_fc.in_features
        if source_fc.out_features != num_classes:
            # The pretrained (1000, in_feature) head cannot be copied into a
            # (num_classes, in_feature) buffer; use a default-initialised
            # linear layer of the requested shape as the frozen base instead.
            source_fc = nn.Linear(in_feature, num_classes, bias=source_fc.bias is not None)

        has_bias = source_fc.bias is not None
        head = RLRRLinear(in_feature, num_classes, bias=has_bias)
        with torch.no_grad():
            head.weight.copy_(source_fc.weight)
            if has_bias:
                head.bias.copy_(source_fc.bias)
        self.resnet.fc = head

        # Freeze the backbone; only parameters of the head stay trainable.
        for name,param in self.resnet.named_parameters():
            if not name.startswith("fc."):
                param.requires_grad = False

    def forward(self,x):
        """Return the logits of the wrapped ResNet-50 for input batch ``x``."""
        return self.resnet(x)

In [13]:
if __name__ == "__main__":
    # Smoke test: push one random batch through the model and check the
    # shape of the logits.
    i = torch.randn(8,3,224,224)
    model = ResNet(num_classes=1000)

    # Run the check in eval mode and without autograd: in training mode the
    # BatchNorm layers would fold this random-noise batch into their
    # pretrained running statistics, and the graph would be built for nothing.
    was_training = model.training
    model.eval()
    with torch.no_grad():
        logits = model(i)
    model.train(was_training)  # restore the mode the model was created in
    print(logits.shape)

    # Optimise only the parameters left trainable by ResNet.__init__.
    optimizer = torch.optim.Adam(filter(lambda p:p.requires_grad,model.parameters()),lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    

torch.Size([8, 1000])


In [1]:
import torch
# Report whether PyTorch can see a usable CUDA device from this kernel.
cuda_ready = torch.cuda.is_available()
print(cuda_ready)

True
