In [1]:
import torch
from torch import autograd
from torch import nn
import torchvision
from torch import optim
import torchvision.transforms as transforms

In [84]:
class SLinearFunction(autograd.Function):

    # Note that both forward and backward are @staticmethods
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, inputS, weight, weightS, bias=None):
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        # outputS = inputS.mm(weightS.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output, torch.ones_like(output)#outputS

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output, grad_outputS):
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_inputS = grad_bias = grad_weightS = None

        # print(f"g: {grad_output}")
        # These needs_input_grad checks are optional and there only to
        # improve efficiency. If you want to make your code simpler, you can
        # skip them. Returning gradients for inputs that don't require it is
        # not an error.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_inputS = grad_outputS.mm(weight**2)
        if ctx.needs_input_grad[2]:
            grad_weight = grad_output.t().mm(input)
        if ctx.needs_input_grad[3]:
            grad_weightS = grad_outputS.t().mm(input**2)
        if bias is not None and ctx.needs_input_grad[4]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_inputS, grad_weight, grad_weightS, grad_bias

class SMSEFunction(autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, inputS, target, size_average=None, reduce=None, reduction='mean'):
        function = torch.nn.functional.mse_loss
        output = function(input, target, size_average, reduce, reduction)
        ctx.save_for_backward(input, target)
        return output

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        input, target = ctx.saved_tensors
        grad_input = 2 * (input - target)

        return grad_input, torch.ones_like(grad_input) * 2, None, None, None, None

def is_nan(x):
    return torch.isnan(x).sum() != 0

def nan_print(x):
    x = x.tolist()
    for i in x:
        print(i)

def test_nan(exp, exp_sum, g_input, g_inputS, ratio):
    if is_nan(g_input) or is_nan(g_inputS):
        torch.save([exp.cpu().numpy(), exp_sum.cpu().numpy()], "debug.pt")
        print(is_nan(g_input), is_nan(g_inputS))
        raise Exception

class SCrossEntropyLossFunction(autograd.Function):
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, inputS, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean'):
        function = torch.nn.functional.cross_entropy
        output = function(input, target, weight, size_average, ignore_index, reduce, reduction)
        ctx.save_for_backward(input, target)
        return output

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        eps = pow(2,-10)
        input, target = ctx.saved_tensors

        the_max = torch.max(input, dim=1)[0].unsqueeze(1).expand_as(input)
        exp = torch.exp(input - the_max)
        exp_sum = exp.sum(dim=1).unsqueeze(1).expand_as(input) + eps
        ratio = exp / exp_sum

        grad_input_mask = torch.zeros_like(input)
        l_index = torch.LongTensor(range(len(input))).to(grad_input_mask.device)
        grad_input_mask[l_index, target] = 1
        grad_input = (ratio - grad_input_mask)/len(input)
        grad_inputS = (exp_sum - exp) * exp / (exp_sum ** 2)
        # grad_input = (ratio - grad_input_mask)/len(input)
        grad_inputS = (1 - ratio) * ratio
        
        test_nan(exp, exp_sum, grad_input, grad_inputS, ratio)

        return grad_input, grad_inputS, None, None, None, None

In [93]:
class SLinear(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super().__init__()
        self.op = nn.Linear(in_features, out_features, bias)
        self.weightS = nn.Parameter(torch.ones(self.op.weight.size()).requires_grad_())
        # self.weightS.retain_grad()
        self.function = SLinearFunction.apply
    
    def push_S_device(self):
        self.weightS = self.weightS.to(self.op.weight.device)

    def clear_S_grad(self):
        with torch.no_grad():
            if self.weightS.grad is not None:
                self.weightS.grad.data *= 0
    
    def fetch_S_grad(self):
        return self.weightS.grad.sum()

    def do_second(self):
        self.op.weight.grad.data = self.op.weight.grad.data / (self.weightS.grad.data + 1e-10)

    def forward(self, x, xS):
        x, xS = self.function(x, xS, self.op.weight, self.weightS, self.op.bias)
        return x, xS

class SReLU(nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = nn.ReLU()
    
    def forward(self, x, xS):
        return self.relu(x), self.relu(xS)

class SModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = SLinear(28*28,32)
        self.fc2 = SLinear(32,32)
        self.fc3 = SLinear(32,10)
        self.relu = SReLU()
    
    def push_S_device(self):
        for m in self.modules():
            if isinstance(m, SLinear):
                m.push_S_device()

    def clear_S_grad(self):
        for m in self.modules():
            if isinstance(m, SLinear):
                m.clear_S_grad()

    def do_second(self):
        for m in self.modules():
            if isinstance(m, SLinear):
                m.do_second()

    def fetch_S_grad(self):
        S_grad_sum = 0
        for m in self.modules():
            if isinstance(m, SLinear):
                S_grad_sum += m.fetch_S_grad()
        return S_grad_sum
    


    def forward(self, x):
        xS = torch.zeros_like(x)
        x, xS = self.fc1(x, xS)
        x, xS = self.relu(x, xS)
        x, xS = self.fc2(x, xS)
        x, xS = self.relu(x, xS)
        x, xS = self.fc3(x, xS)
        return x, xS

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28,32)
        self.fc2 = nn.Linear(32,32)
        self.fc3 = nn.Linear(32,10)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [94]:
def eval():
    total = 0
    correct = 0
    with torch.no_grad():
        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)
            images = images.view(-1, 784)
            outputs = model(images)
            predictions = outputs.argmax(dim=1)
            correction = predictions == labels
            correct += correction.sum()
            total += len(correction)
    return correct/total

def Seval():
    total = 0
    correct = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            images = images.view(-1, 784)
            outputs = model(images)
            predictions = outputs[0].argmax(dim=1)
            correction = predictions == labels
            correct += correction.sum()
            total += len(correction)
    return correct/total

In [95]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

BS = 128

trainset = torchvision.datasets.MNIST(root='~/testCode/data', train=True,
                                        download=False, transform=transforms.ToTensor())
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BS,
                                        shuffle=True, num_workers=2)

testset = torchvision.datasets.MNIST(root='~/testCode/data', train=False,
                                    download=False, transform=transforms.ToTensor())
testloader = torch.utils.data.DataLoader(testset, batch_size=BS,
                                            shuffle=False, num_workers=2)
# model = Model()
# model.to(device)
# criteria = nn.CrossEntropyLoss()
# optimizer = optim.Adam(model.parameters(), lr=0.001)

# for _ in range(10):
#     for images, labels in trainloader:
#         optimizer.zero_grad()
#         images, labels = images.to(device), labels.to(device)
#         images = images.view(-1, 784)
#         outputs = model(images)
#         loss = criteria(outputs, labels)
#         loss.backward()
#         optimizer.step()
#     print(f"test acc: {eval():.4f}")

In [96]:
device = torch.device("cuda:0")
model = SModel()
model.to(device)
model.push_S_device()
criteria = SCrossEntropyLossFunction.apply
optimizer = optim.Adam(model.parameters(), lr=0.001)

for _ in range(100):
    running_loss = 0.
    running_l = 0.
    for i, (images, labels) in enumerate(trainloader):
        optimizer.zero_grad()
        model.clear_S_grad()
        images, labels = images.to(device), labels.to(device)
        images = images.view(-1, 784)
        outputs, outputsS = model(images)
        loss = criteria(outputs, outputsS,labels)
        loss.backward()
        l = loss + model.fetch_S_grad()
        model.do_second()
        optimizer.step()
        running_loss += loss.item()
        running_l += l.item()
        # print(running_loss)
    print(f"test acc: {Seval():.4f}, loss: {running_loss / len(trainloader)}, s: {(running_l - running_loss) / len(trainloader)}")

test acc: 0.0980, loss: 533.2699938369458, s: 2171325.3065260295
test acc: 0.0958, loss: 479.8668044035369, s: 121462.35733022467


KeyboardInterrupt: 

In [None]:
device = torch.device("cuda")
model = SModel()
model.to(device)
model.push_S_device()
criteria = SCrossEntropyLossFunction.apply
optimizer = optim.Adam(model.parameters(), lr=0.001)

for _ in range(100):
    running_loss = 0.
    running_l = 0.
    for images, labels in trainloader:
        optimizer.zero_grad()
        images, labels = images.to(device), labels.to(device)
        images = images.view(-1, 784)
        outputs, outputsS = model(images)
        loss = criteria(outputs, outputsS,labels)
        loss.backward()
        l = loss + model.fetch_S_grad()
        optimizer.step()
        optimizer.zero_grad()
        model.clear_S_grad()
        running_loss += loss.item()
        running_l += l.item()
    print(f"test acc: {Seval():.4f}, loss: {running_loss / len(trainloader)}, s: {(running_l - running_loss) / len(trainloader)}")

In [83]:
a = torch.Tensor([[1,2,3],[1,2,3]])
torch.exp(a - torch.max(a, dim=1)[0])

RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 1

In [73]:
a

tensor([[1., 2.],
        [1., 2.]])

In [76]:
torch.max(a, dim=1)

torch.return_types.max(
values=tensor([2., 2.]),
indices=tensor([1, 1]))