In [1]:
import torch
from torch import autograd
from torch import nn
import torchvision
from torch import optim
import torchvision.transforms as transforms
from tqdm.notebook import tqdm
import numpy as np
from models import SCrossEntropyLoss, OCrossEntropyLoss, SMLP3, SMLP4
from Functions import SCrossEntropyLossFunction
from tqdm.notebook import tqdm

In [15]:
def eval(loader=None):
    """Return top-1 accuracy of the global ``model`` with noise and mask cleared.

    NOTE: this name shadows Python's built-in ``eval()`` in the notebook namespace.

    Args:
        loader: iterable of ``(images, labels)`` batches. Defaults to the global
            ``trainloader`` -- so, unlike ``Seval``, this reports training-set
            accuracy unless another loader is passed.

    Returns:
        Accuracy in [0, 1] as a Python float.
    """
    if loader is None:
        loader = trainloader
    total = 0
    correct = 0
    model.clear_noise()
    model.clear_mask()
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            images = images.view(-1, 784)  # flatten 28x28 MNIST images
            outputs = model(images)
            # Elsewhere in this notebook the model output is unpacked into two
            # values (STrain/GetSecond) or indexed with [0] (Seval); take the
            # first element in that case, but still accept a bare logits tensor.
            if isinstance(outputs, (tuple, list)):
                outputs = outputs[0]
            predictions = outputs.argmax(dim=1)
            correction = predictions == labels
            correct += correction.sum()
            total += len(correction)
    return (correct / total).item()

def Seval(is_clear_mask=True):
    """Top-1 test-set accuracy of the global ``model`` with noise removed.

    Args:
        is_clear_mask: when True (default) the model's mask is cleared before
            scoring; pass False to evaluate with whatever mask is currently set.

    Returns:
        Accuracy in [0, 1] as a Python float.
    """
    n_seen = 0
    n_right = 0
    with torch.no_grad():
        model.clear_noise()
        if is_clear_mask:
            model.clear_mask()
        for batch_x, batch_y in testloader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            # Class scores are the first element of the model output.
            scores = model(batch_x.view(-1, 784))[0]
            hits = scores.argmax(dim=1) == batch_y
            n_right += hits.sum()
            n_seen += len(hits)
    return (n_right / n_seen).item()

def Seval_noise(var, is_clear_mask=True):
    """Top-1 test-set accuracy of the global ``model`` with noise injected.

    Previous noise is cleared (and, optionally, the mask), then
    ``model.set_noise(var)`` is applied before scoring ``testloader``.

    NOTE(review): the noise set here is not cleared on return, so the model
    stays noisy until something calls ``model.clear_noise()`` -- confirm that is
    intended before training or computing gradients afterwards.

    Args:
        var: passed unchanged to ``model.set_noise`` (presumably a noise
            variance -- TODO confirm against the model implementation).
        is_clear_mask: clear the model's mask first when True (default).

    Returns:
        Accuracy in [0, 1] as a Python float.
    """
    model.clear_noise()
    if is_clear_mask:
        model.clear_mask()
    hits_total = 0
    seen_total = 0
    with torch.no_grad():
        model.set_noise(var)
        for x, y in testloader:
            x = x.to(device)
            y = y.to(device)
            guess = model(x.view(-1, 784))[0].argmax(dim=1)
            match = guess == y
            hits_total += match.sum()
            seen_total += len(match)
    return (hits_total / seen_total).item()

def STrain(epochs, scheduler=None):
    """Train the global ``model`` and checkpoint the best test accuracy.

    Each epoch makes one pass over ``trainloader`` with the global ``optimizer``
    and ``criteria``, then scores the model with ``Seval()``. Whenever the test
    accuracy beats the best seen so far, ``model.state_dict()`` is written to
    ``tmp_best.pt``. One summary line is printed per epoch; its ``s`` field is
    the per-batch mean of ``model.fetch_H_grad()`` (logged total minus loss).

    Args:
        epochs: number of passes over the training set.
        scheduler: optional learning-rate scheduler, stepped once after each
            epoch's optimizer updates (e.g. a ``MultiStepLR`` built on the same
            optimizer). The default ``None`` keeps the learning rate fixed, as
            before this parameter existed.
    """
    best_acc = 0.0
    for i in range(epochs):
        running_loss = 0.
        running_l = 0.
        for images, labels in tqdm(trainloader):
            optimizer.zero_grad()
            images, labels = images.to(device), labels.to(device)
            images = images.view(-1, 784)  # flatten 28x28 MNIST images
            outputs, outputsS = model(images)
            loss = criteria(outputs, outputsS, labels)
            loss.backward()
            # Used only for logging; evaluated after backward() (presumably
            # because fetch_H_grad reads gradients -- TODO confirm).
            l = loss + model.fetch_H_grad()
            optimizer.step()
            running_loss += loss.item()
            running_l += l.item()
        # Epoch-level LR schedule (e.g. MultiStepLR milestones count epochs).
        if scheduler is not None:
            scheduler.step()
        test_acc = Seval()
        if test_acc > best_acc:
            best_acc = test_acc
            torch.save(model.state_dict(), "tmp_best.pt")
        print(f"epoch: {i:-3d}, test acc: {test_acc:.4f}, loss: {running_loss / len(trainloader):.4f}, s: {(running_l - running_loss) / len(trainloader):-5.4f}")

def GetSecond():
    """Accumulate ``SCrossEntropyLoss`` gradients over the entire training set.

    Gradients are zeroed once through the global ``optimizer`` and then summed
    across every batch of ``trainloader`` by repeated ``backward()`` calls. No
    optimizer step is taken and nothing is returned; callers inspect the
    resulting ``.grad`` tensors afterwards (the notebook prints
    ``model.fcN.weightH.grad.max()``).

    NOTE(review): ``optimizer.zero_grad()`` only clears gradients of tensors
    registered with the optimizer. If the ``weightH`` tensors are not among
    ``model.parameters()``, repeated calls keep accumulating -- verify.
    """
    optimizer.zero_grad()
    loss_function = SCrossEntropyLoss()
    for images, labels in tqdm(trainloader):
        images, labels = images.to(device), labels.to(device)
        images = images.view(-1, 784)  # flatten 28x28 MNIST images
        outputs, outputsS = model(images)
        loss = loss_function(outputs, outputsS, labels)
        loss.backward()  # no zero_grad inside the loop: gradients sum over batches


# Device selection. `args` is not defined in any visible cell (presumably a
# leftover from an argparse-driven script), so reading `args.device` raised
# NameError on a fresh kernel. Use it when it exists, otherwise request CUDA.
_requested_device = getattr(globals().get("args"), "device", "cuda")
device = torch.device(_requested_device if torch.cuda.is_available() else "cpu")

BS = 128  # batch size for both loaders
DATA_ROOT = '~/Private/data'  # MNIST must already be present here (download=False)

trainset = torchvision.datasets.MNIST(root=DATA_ROOT, train=True,
                                      download=False, transform=transforms.ToTensor())
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BS,
                                          shuffle=True, num_workers=2)

testset = torchvision.datasets.MNIST(root=DATA_ROOT, train=False,
                                     download=False, transform=transforms.ToTensor())
testloader = torch.utils.data.DataLoader(testset, batch_size=BS,
                                         shuffle=False, num_workers=2)

In [21]:
# Train SMLP3 in its `to_first()` mode with OCrossEntropyLoss
# (SCrossEntropyLoss is the criterion GetSecond uses instead).
LR = 0.01
LR_MILESTONES = [20]
N_EPOCHS = 1

# Move the model to the same device the training/eval loops send batches to;
# do it before building the optimizer so it sees the moved parameters.
model = SMLP3().to(device)
optimizer = optim.Adam(model.parameters(), lr=LR)
# NOTE(review): nothing in the visible code steps `scheduler`, so the
# MultiStepLR milestone has no effect yet.
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, LR_MILESTONES)
criteria = OCrossEntropyLoss()
model.to_first()
STrain(N_EPOCHS)

HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


epoch:   0, test acc: 0.9400, loss: 0.3093, s: 0.0000


In [22]:
# Switch the model to its second mode, accumulate gradients over the training
# set, then report the largest accumulated weightH gradient of each layer.
model.to_second()
GetSecond()
for layer in (model.fc1, model.fc2, model.fc3):
    print(layer.weightH.grad.max())

HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))


tensor(369.9479)
tensor(23018.9043)
tensor(38573.5938)


In [23]:
# Apply a mask to the model, then report test accuracy while keeping that mask.
# The meaning of these arguments is defined by the model's set_mask
# (presumably an amount and a threshold mode -- TODO confirm).
MASK_AMOUNT = 10000
MASK_MODE = "th"
model.set_mask(MASK_AMOUNT, MASK_MODE)
Seval(is_clear_mask=False)  # last expression, so the accuracy is shown as cell output

0.858299970626831

In [16]:
# NOTE(review): this cell ran as In[16], before In[21]-In[23], so it depends on
# out-of-order execution. GetSecond() prints nothing itself, so the tensor shown
# below this cell is stale output from an earlier version of the cell.
GetSecond()


HBox(children=(FloatProgress(value=0.0, max=469.0), HTML(value='')))




tensor(13375.1816)