In [1]:
import torch
from torch import autograd
from torch import nn
import torchvision
from torch import optim
import torchvision.transforms as transforms
from tqdm.notebook import tqdm
import numpy as np
from models import TCrossEntropyLoss, SCrossEntropyLoss, OCrossEntropyLoss, SMLP3, SMLP4
from Functions import SCrossEntropyLossFunction
from tqdm.notebook import tqdm

In [14]:
def eval():
    """Return the model's accuracy over ``trainloader`` with noise and mask cleared.

    Relies on the notebook globals ``model``, ``trainloader`` and ``device``.
    Each batch is flattened to 784 features per sample before the forward
    pass, and the whole sweep runs under ``torch.no_grad()``.

    The training loops in this notebook unpack ``model(images)`` into two
    values and ``Seval`` indexes ``outputs[0]``, so a tuple/list result is
    reduced to its first element before ``argmax``. A plain tensor result is
    used unchanged.

    NOTE(review): this iterates ``trainloader`` while ``Seval`` iterates
    ``testloader`` -- presumably this is meant to report training accuracy;
    confirm. The name also shadows the builtin ``eval``.

    Returns:
        float: number of correct predictions divided by the number of samples.
    """
    total = 0
    correct = 0
    model.clear_noise()
    model.clear_mask()
    with torch.no_grad():
        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)
            images = images.view(-1, 784)
            outputs = model(images)
            # The model is used as a two-output network everywhere else in
            # this notebook; calling .argmax on the raw pair would raise.
            if isinstance(outputs, (tuple, list)):
                outputs = outputs[0]
            predictions = outputs.argmax(dim=1)
            correction = predictions == labels
            correct += correction.sum()
            total += len(correction)
    return (correct/total).item()

def Seval(is_clear_mask=True):
    """Return the noise-free accuracy of the global ``model`` over ``testloader``.

    Args:
        is_clear_mask: when true, ``model.clear_mask()`` is called before the
            sweep; when false the model's current mask is left in place.

    Returns:
        float: number of correct predictions divided by the number of samples.
        Predictions are the ``argmax`` over dim 1 of the first element of the
        model's output.
    """
    n_seen = 0
    n_hit = 0
    with torch.no_grad():
        # Noise is always cleared; the mask only on request.
        model.clear_noise()
        if is_clear_mask:
            model.clear_mask()
        for batch_x, batch_y in testloader:
            batch_x = batch_x.to(device).view(-1, 784)
            batch_y = batch_y.to(device)
            primary_out = model(batch_x)[0]
            hits = primary_out.argmax(dim=1) == batch_y
            n_hit += hits.sum()
            n_seen += len(hits)
    accuracy = n_hit / n_seen
    return accuracy.item()

def Seval_noise(var, is_clear_mask=True):
    """Return the accuracy of the global ``model`` over ``testloader`` with noise set.

    Any existing noise is cleared first (and the mask too, unless
    ``is_clear_mask`` is false), then ``model.set_noise(var)`` is applied once
    before the sweep. The noise is not cleared again before returning.

    Args:
        var: value forwarded to ``model.set_noise`` -- presumably a noise
            variance; confirm against the model definition.
        is_clear_mask: when true, ``model.clear_mask()`` is called first.

    Returns:
        float: number of correct predictions divided by the number of samples.
    """
    n_seen = 0
    n_hit = 0
    model.clear_noise()
    if is_clear_mask:
        model.clear_mask()
    with torch.no_grad():
        model.set_noise(var)
        for batch_x, batch_y in testloader:
            batch_x = batch_x.to(device).view(-1, 784)
            batch_y = batch_y.to(device)
            primary_out = model(batch_x)[0]
            hits = primary_out.argmax(dim=1) == batch_y
            n_hit += hits.sum()
            n_seen += len(hits)
    accuracy = n_hit / n_seen
    return accuracy.item()

def STrain(epochs):
    """Train the global ``model`` for ``epochs`` epochs, checkpointing the best one.

    Uses the notebook globals ``model``, ``optimizer``, ``criteria``,
    ``trainloader`` and ``device``. After each epoch the test accuracy from
    ``Seval()`` is computed; whenever it beats the best value so far, the
    model's ``state_dict`` is written to ``tmp_best.pt``.

    The per-epoch log line reports the mean batch loss and ``s``, the mean
    per-batch difference between ``loss + model.fetch_H_grad()`` and ``loss``.

    Args:
        epochs: number of passes over ``trainloader``.
    """
    top_acc = 0.0
    for epoch in range(epochs):
        loss_sum = 0.
        augmented_sum = 0.
        for batch_x, batch_y in tqdm(trainloader, leave=False):
            optimizer.zero_grad()
            batch_x = batch_x.to(device).view(-1, 784)
            batch_y = batch_y.to(device)
            out_a, out_b = model(batch_x)
            loss = criteria(out_a, out_b, batch_y)
            loss.backward()
            # Read the extra term after backward() and before the parameters
            # are updated, as in the original ordering.
            augmented = loss + model.fetch_H_grad()
            optimizer.step()
            loss_sum += loss.item()
            augmented_sum += augmented.item()

        test_acc = Seval()
        if test_acc > top_acc:
            top_acc = test_acc
            torch.save(model.state_dict(), "tmp_best.pt")

        mean_loss = loss_sum / len(trainloader)
        mean_s = (augmented_sum - loss_sum) / len(trainloader)
        print(f"epoch: {epoch:-3d}, test acc: {test_acc:.4f}, loss: {mean_loss:.4f}, s: {mean_s:-5.4f}")

def TTrain(epochs, alpha):
    """Train the global ``model`` for ``epochs`` epochs, calling ``model.do_third(alpha)`` each step.

    Uses the notebook globals ``model``, ``optimizer``, ``criteria``,
    ``trainloader`` and ``device``. For every batch the loss is
    back-propagated, ``model.do_third(alpha)`` is invoked, and only then is
    the optimizer stepped. After each epoch the test accuracy from
    ``Seval()`` and the mean batch loss are printed. No checkpoint is saved.

    Args:
        epochs: number of passes over ``trainloader``.
        alpha: value forwarded unchanged to ``model.do_third`` -- presumably a
            weighting coefficient; confirm against the model definition.
    """
    for i in range(epochs):
        running_loss = 0.
        for images, labels in tqdm(trainloader, leave=False):
            optimizer.zero_grad()
            images, labels = images.to(device), labels.to(device)
            images = images.view(-1, 784)
            outputs, outputsS = model(images)
            loss = criteria(outputs, outputsS,labels)
            loss.backward()
            running_loss += loss.item()
            # do_third runs after backward() and before step(), so it can act
            # on the freshly computed gradients.
            model.do_third(alpha)
            optimizer.step()
        test_acc = Seval()
        print(f"epoch: {i:-3d}, test acc: {test_acc:.4f}, loss: {running_loss / len(trainloader):.4f}")

def GetSecond():
    """Accumulate ``SCrossEntropyLoss`` gradients over the whole of ``trainloader``.

    Gradients are zeroed once through the global ``optimizer``, then every
    batch's loss is back-propagated without any optimizer step, so ``.grad``
    fields hold the sum over all batches when the function returns. Callers
    in this notebook read ``model.fcN.weightH.grad`` afterwards.

    Returns nothing; the result is left in the model's ``.grad`` attributes.
    """
    optimizer.zero_grad()
    loss_function = SCrossEntropyLoss()
    for images, labels in tqdm(trainloader, leave=False):
        images, labels = images.to(device), labels.to(device)
        images = images.view(-1, 784)
        outputs, outputsS = model(images)
        loss = loss_function(outputs, outputsS,labels)
        # No zero_grad()/step() inside the loop: gradients accumulate.
        loss.backward()

def GetThird():
    """Accumulate ``TCrossEntropyLoss`` gradients over the whole of ``trainloader``.

    Gradients are zeroed once through the global ``optimizer``, then every
    batch's loss is back-propagated without any optimizer step, so ``.grad``
    fields hold the sum over all batches when the function returns. Callers
    in this notebook read ``model.fcN.weightH.grad`` afterwards.

    Returns nothing; the result is left in the model's ``.grad`` attributes.
    """
    optimizer.zero_grad()
    loss_function = TCrossEntropyLoss()
    for images, labels in tqdm(trainloader, leave=False):
        images, labels = images.to(device), labels.to(device)
        images = images.view(-1, 784)
        outputs, outputsS = model(images)
        loss = loss_function(outputs, outputsS,labels)
        # No zero_grad()/step() inside the loop: gradients accumulate.
        loss.backward()


# Select the compute device. No cell in this notebook defines `args`, so use
# `args.device` only when such an object exists and default to "cuda"
# otherwise; without CUDA the CPU is always used.
_requested_device = getattr(globals().get("args"), "device", "cuda")
device = torch.device(_requested_device if torch.cuda.is_available() else "cpu")

# Mini-batch size shared by the train and test loaders.
BS = 128

# MNIST is read from a local copy (download=False); images are converted to
# tensors with no further preprocessing.
trainset = torchvision.datasets.MNIST(root='~/Private/data', train=True,
                                        download=False, transform=transforms.ToTensor())
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BS,
                                        shuffle=True, num_workers=2)

testset = torchvision.datasets.MNIST(root='~/Private/data', train=False,
                                    download=False, transform=transforms.ToTensor())
testloader = torch.utils.data.DataLoader(testset, batch_size=BS,
                                            shuffle=False, num_workers=2)

In [3]:
# Build a fresh network and train it for 10 epochs with OCrossEntropyLoss,
# then accumulate SCrossEntropyLoss gradients and print the largest weightH
# gradient entry of each layer.
model = SMLP3()
# Every batch is moved to `device` by the train/eval helpers, so the model has
# to live there too; on CPU this is a no-op.
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.01)
# NOTE(review): the scheduler is created but never stepped by STrain/TTrain.
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [20])
criteria = OCrossEntropyLoss()
model.to_first()
STrain(10)
model.to_second()
GetSecond()
print(model.fc1.weightH.grad.max())
print(model.fc2.weightH.grad.max())
print(model.fc3.weightH.grad.max())

HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))

epoch:   0, test acc: 0.9450, loss: 0.3037, s: 0.0000


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))

epoch:   1, test acc: 0.9533, loss: 0.1630, s: 0.0000


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))

epoch:   2, test acc: 0.9604, loss: 0.1391, s: 0.0000


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))

epoch:   3, test acc: 0.9544, loss: 0.1280, s: 0.0000


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))

epoch:   4, test acc: 0.9557, loss: 0.1163, s: 0.0000


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))

epoch:   5, test acc: 0.9599, loss: 0.1109, s: 0.0000


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))

epoch:   6, test acc: 0.9596, loss: 0.1039, s: 0.0000


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))

epoch:   7, test acc: 0.9639, loss: 0.1014, s: 0.0000


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))

epoch:   8, test acc: 0.9612, loss: 0.0949, s: 0.0000


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))

epoch:   9, test acc: 0.9611, loss: 0.0990, s: 0.0000


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(852.3397)
tensor(45843.7148)
tensor(32539.6758)


In [3]:
# Re-create the network and its optimizer for the checkpoint-based
# experiments in the following cells.
model = SMLP3()
# Every batch is moved to `device` by the train/eval helpers, so the model has
# to live there too; on CPU this is a no-op.
model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.01)
# NOTE(review): the scheduler is created but never stepped by STrain/TTrain.
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [20])

<All keys matched successfully>

In [4]:
# Experiment: one epoch of TTrain with alpha = 0, starting from the saved
# checkpoint, followed by the accumulated SCrossEntropyLoss gradients.
# NOTE(review): "in_use.pt" is not written by any cell shown here (STrain
# saves "tmp_best.pt") -- presumably a renamed copy; confirm.
# map_location="cpu" lets a checkpoint saved from CUDA load on a CPU-only
# machine; load_state_dict copies the values onto the model's own device.
state_dict = torch.load("in_use.pt", map_location="cpu")
model.to_first()
model.load_state_dict(state_dict)
model.to_third()
criteria = TCrossEntropyLoss()
TTrain(1,0)
model.to_second()
GetSecond()
print(model.fc1.weightH.grad.max())
print(model.fc2.weightH.grad.max())
print(model.fc3.weightH.grad.max())

HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))

epoch:   0, test acc: 0.9596, loss: 0.1016


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(1540.1842)
tensor(139313.7812)
tensor(29841.2227)


In [5]:
# Experiment: one epoch of TTrain with alpha = 1e-10, starting from the saved
# checkpoint, followed by the accumulated SCrossEntropyLoss gradients.
# NOTE(review): "in_use.pt" is not written by any cell shown here (STrain
# saves "tmp_best.pt") -- presumably a renamed copy; confirm.
# map_location="cpu" lets a checkpoint saved from CUDA load on a CPU-only
# machine; load_state_dict copies the values onto the model's own device.
state_dict = torch.load("in_use.pt", map_location="cpu")
model.to_first()
model.load_state_dict(state_dict)
model.to_third()
criteria = TCrossEntropyLoss()
TTrain(1,1e-10)
model.to_second()
GetSecond()
print(model.fc1.weightH.grad.max())
print(model.fc2.weightH.grad.max())
print(model.fc3.weightH.grad.max())

HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))

epoch:   0, test acc: 0.9636, loss: 0.0963


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(1453.2732)
tensor(50952.0117)
tensor(30372.1094)


In [6]:
# Experiment: one epoch of TTrain with alpha = 1e-5, starting from the saved
# checkpoint, followed by the accumulated SCrossEntropyLoss gradients.
# NOTE(review): "in_use.pt" is not written by any cell shown here (STrain
# saves "tmp_best.pt") -- presumably a renamed copy; confirm.
# map_location="cpu" lets a checkpoint saved from CUDA load on a CPU-only
# machine; load_state_dict copies the values onto the model's own device.
state_dict = torch.load("in_use.pt", map_location="cpu")
model.to_first()
model.load_state_dict(state_dict)
model.to_third()
criteria = TCrossEntropyLoss()
TTrain(1,1e-5)
model.to_second()
GetSecond()
print(model.fc1.weightH.grad.max())
print(model.fc2.weightH.grad.max())
print(model.fc3.weightH.grad.max())

HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))

epoch:   0, test acc: 0.9618, loss: 0.1010


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(19199.8223)
tensor(118912.6094)
tensor(35265.8984)


In [15]:
# Experiment: no training. Reload the checkpoint, accumulate the
# TCrossEntropyLoss gradients (GetThird) and then the SCrossEntropyLoss
# gradients (GetSecond), printing per-layer weightH gradient extrema.
# NOTE(review): "in_use.pt" is not written by any cell shown here (STrain
# saves "tmp_best.pt") -- presumably a renamed copy; confirm.
# map_location="cpu" lets a checkpoint saved from CUDA load on a CPU-only
# machine; load_state_dict copies the values onto the model's own device.
state_dict = torch.load("in_use.pt", map_location="cpu")
model.to_first()
model.load_state_dict(state_dict)
model.to_third()
# GetThird builds its own loss object; `criteria` is set for later cells.
criteria = TCrossEntropyLoss()
GetThird()
# Largest magnitude after the GetThird pass.
print((model.fc1.weightH.grad).abs().max())
print((model.fc2.weightH.grad).abs().max())
print((model.fc3.weightH.grad).abs().max())
model.to_second()
GetSecond()
# NOTE(review): the GetSecond pass prints the signed maximum, not the largest
# magnitude as above -- confirm this asymmetry is intended.
print(model.fc1.weightH.grad.max())
print(model.fc2.weightH.grad.max())
print(model.fc3.weightH.grad.max())

HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(73.2722)
tensor(36563.8711)
tensor(158410.7344)


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(780.9372)
tensor(27343.7852)
tensor(24844.2266)


In [19]:
# Experiment: ten epochs of TTrain with alpha = 1e-4, starting from the saved
# checkpoint, followed by the accumulated SCrossEntropyLoss gradients.
# NOTE(review): "in_use.pt" is not written by any cell shown here (STrain
# saves "tmp_best.pt") -- presumably a renamed copy; confirm.
# map_location="cpu" lets a checkpoint saved from CUDA load on a CPU-only
# machine; load_state_dict copies the values onto the model's own device.
state_dict = torch.load("in_use.pt", map_location="cpu")
model.to_first()
model.load_state_dict(state_dict)
model.to_third()
criteria = TCrossEntropyLoss()
TTrain(10,1e-4)
model.to_second()
GetSecond()
print(model.fc1.weightH.grad.max())
print(model.fc2.weightH.grad.max())
print(model.fc3.weightH.grad.max())

HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))

epoch:   0, test acc: 0.9655, loss: 0.0887


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))

epoch:   1, test acc: 0.9679, loss: 0.0630


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))

epoch:   2, test acc: 0.9681, loss: 0.0574


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))

epoch:   3, test acc: 0.9689, loss: 0.0552


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))

epoch:   4, test acc: 0.9656, loss: 0.0553


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))

epoch:   5, test acc: 0.9491, loss: 0.1216


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))

epoch:   6, test acc: 0.9498, loss: 0.1479


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))

epoch:   7, test acc: 0.9278, loss: 0.2219


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))

epoch:   8, test acc: 0.9333, loss: 0.3461


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))

epoch:   9, test acc: 0.9207, loss: 0.4246


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))

tensor(163653.7969)
tensor(444814.7500)
tensor(633621.7500)
