In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torch.nn.functional as F
from copy import deepcopy

n_epochs = 5 # number of epochs for training
batch_size_train = 1024 # batch size for training
batch_size_test = 8192 # batch size for testing
learning_rate = 0.001 # learning rate for Adam
log_interval = 10 # logging interval for metrics
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # device for computation

# fixing random seed
random_seed = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)

<torch._C.Generator at 0x1554cbb6b390>

In [2]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.f1 = F1()
        self.f2 = F2()
    
    def forward(self, x, y_hat, return_mid=False):
        z = self.f1(x)
        y_pred, z_pred  = self.f2(z, y_hat)
        return y_pred, z_pred
    
    
class F1(nn.Module):
    def __init__(self):
        super(F1, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.activation = nn.LeakyReLU(negative_slope=0.01)

    def forward(self, x):
        x = self.activation(F.max_pool2d(self.conv1(x), 2))
        x = self.activation(F.max_pool2d(self.conv2(x), 2))
        x = x.view(-1, 320)
        x = self.activation(self.fc1(x))
        return x
    
    
class F2(nn.Module):
    def __init__(self):
        super(F2, self).__init__()
        self.fc2 = nn.Linear(50, 50)
        self.fc3 = nn.Linear(50, 60)
        self.y_hat_fc = nn.Linear(10, 50)
        self.activation = nn.LeakyReLU(negative_slope=0.01)

    def forward(self, x, y_hat):
        x = x + self.activation(self.y_hat_fc(y_hat))
        x = self.activation(self.fc2(x))
        x = self.fc3(x)
        return x[:, :10].softmax(-1), x[:, 10:]

In [3]:
def js_div(p, q):
    m = 0.5 * (p + q)
    return 0.5 * (F.kl_div(torch.log(p), m, reduction='batchmean') + 
                  F.kl_div(torch.log(q), m, reduction='batchmean'))

def train(f, train_loader, optimizer, n_epochs, n_classes=10):
  f.train()
  for epoch in range(1, n_epochs + 1):
    for batch_idx, (data, target) in enumerate(train_loader):
      # randomally choose y0 to be either the true target y or a random class
      x, y = data.to(device), target.to(device)
      y0_d = y if torch.rand(1) > 0.25 else torch.randint(0, n_classes, (x.shape[0],), device=device)
      y0 = F.one_hot(y0_d, num_classes=n_classes).float()

      # forward pass
      y1, z1 = f(x, y0)
      y2, z2 = f.f2(z1, y1)

      # losses
      loss_supervised_1 = F.nll_loss(y1.log(), y)
      loss_supervised_2 = F.nll_loss(y2.log(), y)
      loss_unsupervised_y = js_div(y1, y2)*10
      loss_unsupervised_z = (z1 - z2).pow(2).mean()*10
      # loss = loss_supervised + #loss_unsupervised_y # + loss_unsupervised_z
      loss = loss_supervised_1 + loss_supervised_2

      # opt
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      # log
      if batch_idx % log_interval == 0:
        # print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}, loss_supervised: {:.6f}, loss_unsupervised_y: {:.6f}, loss_unsupervised_z: {:.6f}'.format(
        #   epoch, batch_idx * len(data), len(train_loader.dataset),
        #   100. * batch_idx / len(train_loader), loss.item(), loss_supervised.item(), loss_unsupervised_y.item(), loss_unsupervised_z.item()))
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tloss_supervised_1: {:.6f}, loss_supervised_2: {:.6f}'.format(
          epoch, batch_idx * len(data), len(train_loader.dataset),
          100. * batch_idx / len(train_loader), loss_supervised_1.item(), loss_supervised_2.item()))
        torch.save(f.state_dict(), './model.pth')
        torch.save(optimizer.state_dict(), './optimizer.pth')
      

def test(f, test_loader, n_classes=10, with_true_y=False):
  f.eval()
  test_loss, correct = 0, 0
  with torch.no_grad():
    for data, target in test_loader:
      x, y = data.to(device), target.to(device)
      if with_true_y:
        y0 = F.one_hot(y, num_classes=n_classes).float()
      else:
        y0 = torch.ones((len(y), n_classes), device=x.device).to(device).softmax(-1)
      y1, z1 = f(x, y0)
      test_loss += F.nll_loss(y1.log(), y, size_average=False).item()
      pred = y1.log().data.max(1, keepdim=True)[1]
      correct += pred.eq(y.data.view_as(pred)).sum()
  test_loss /= len(test_loader.dataset)
  print("True y given") if with_true_y else print("Max entropy given")
  print('\nTest: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))
  return 100 * correct / len(test_loader.dataset)

# Model Training

In [4]:




# def train(network, train_loader, optimizer, n_epochs, n_classes=10):
#   network.train()
#   train_losses, train_counter = [], []
#   for epoch in range(1, n_epochs + 1):
#     for batch_idx, (data, target) in enumerate(train_loader):
#       # randomally choose y_hat_0 to be either the target y or max entropy
#       x, y = data.to(device), target.to(device)
#       y_hat_opt_a = F.one_hot(y, num_classes=n_classes).float()
#       y_hat_opt_b = torch.ones_like(y_hat_opt_a).softmax(-1)
#       r = (torch.rand(x.shape[0], 1, device=device) > 0).float()
#       y_hat_0 = y_hat_opt_a * r + y_hat_opt_b * (1 - r)
      
#       # apply network, our prediction is called y_hat_1
#       y_hat_1 = network(x, y_hat_0)
#       loss = F.nll_loss(y_hat_1.log(), y)

#       # opt
#       optimizer.zero_grad()
#       loss.backward()
#       optimizer.step()

#       # log
#       if batch_idx % log_interval == 0:
#         print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
#           epoch, batch_idx * len(data), len(train_loader.dataset),
#           100. * batch_idx / len(train_loader), loss.item()))
#         train_losses.append(loss.item())
#         train_counter.append(
#           (batch_idx*64) + ((epoch-1)*len(train_loader.dataset)))
#         torch.save(network.state_dict(), './model.pth')
#         torch.save(optimizer.state_dict(), './optimizer.pth')



# def test(network, test_loader, n_classes=10, with_true_y=False):
#   network.eval()
#   test_loss, correct, test_losses = 0, 0, []
#   with torch.no_grad():
#     for data, target in test_loader:
#       x, y = data.to(device), target.to(device)
#       if with_true_y:
#         y_hat_0 = F.one_hot(y, num_classes=n_classes).float()
#       else:
#         y_hat_0 = torch.ones((len(y), n_classes)).to(device).softmax(-1)
#       y_hat_1 = network(x, y_hat_0)
#       test_loss += F.nll_loss(y_hat_1.log(), y, size_average=False).item()
#       pred = y_hat_1.log().data.max(1, keepdim=True)[1]
#       correct += pred.eq(y.data.view_as(pred)).sum()
#   test_loss /= len(test_loader.dataset)
#   test_losses.append(test_loss)
#   print("True y given") if with_true_y else print("Max entropy given")
#   print('\nTest: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
#     test_loss, correct, len(test_loader.dataset),
#     100. * correct / len(test_loader.dataset)))
#   return 100 * correct / len(test_loader.dataset)

In [5]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('data', train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

f = Net().to(device)
optimizer = optim.Adam(f.parameters(), lr=1e-3)

test(f, test_loader)
train(f, train_loader, optimizer, n_epochs)
test(f, test_loader)
test(f, test_loader, with_true_y=True)



Max entropy given

Test: Avg. loss: 2.3083, Accuracy: 1116/10000 (11%)

Max entropy given

Test: Avg. loss: 0.0365, Accuracy: 9876/10000 (99%)

True y given

Test: Avg. loss: 0.0071, Accuracy: 9977/10000 (100%)



tensor(99.7700, device='cuda:0')

## TTT

In [6]:
def ttt_one_instance(x, f_ttt, f, optimizer, n_steps, y, n_classes=10):
  f_ttt.load_state_dict(f.state_dict())  # reset f_ttt to f
  f_ttt.train()
  for step in range(n_steps):    
    y0 = F.one_hot(torch.randint(0, n_classes, (x.shape[0],), device=device), num_classes=n_classes).float()
    y1, z1 = f_ttt(x, y0)
    y2, z2 = f.f2(z1, y1)
    
    loss_unsupervised_y = js_div(y1, y2)
    loss_unsupervised_z = (z1 - z2).pow(2).mean()
    loss = loss_unsupervised_y #+ loss_unsupervised_z
        
    
    if y[0].item() != y1[0].argmax().item() and (step == 0 or step == n_steps - 1):
      print(f'step {step}: loss={loss.item()}')
      print(y0[0].argmax().item(), y1[0].argmax().item(), y2[0].argmax().item(), y[0].item())
    optimizer.zero_grad()
    loss.backward()
    if f_ttt.f2.fc2.weight.grad.var() == 0:
      print('zero grad')
    optimizer.step()
  return y1, y2


def ttt(f, test_loader, n_steps, lr):
  f_ttt = deepcopy(f)
  f.eval()
  optimizer = optim.Adam(f_ttt.parameters(), lr=lr)
  test_loss_1, correct_1 = 0, 0
  test_loss_2, correct_2 = 0, 0

  for ind, (data, target) in enumerate(test_loader):
    print(f'batch {ind}/{len(test_loader)}:')
    x, y = data.to(device), target.to(device)
    y_hat_1, y_hat_2 = ttt_one_instance(x, f_ttt, f, optimizer, n_steps, y)

    test_loss_1 += F.nll_loss(y_hat_1.log(), y, size_average=False).item()
    test_loss_2 += F.nll_loss(y_hat_2.log(), y, size_average=False).item()

    pred_1 = y_hat_1.data.max(1, keepdim=True)[1]
    pred_2 = y_hat_2.data.max(1, keepdim=True)[1]

    correct_1 += pred_1.eq(y.data.view_as(pred_1)).sum()
    correct_2 += pred_2.eq(y.data.view_as(pred_2)).sum()

  test_loss_1 /= len(test_loader.dataset)
  test_loss_2 /= len(test_loader.dataset)



  print('\nttt y_hat_1: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss_1, correct_1, len(test_loader.dataset),
    100. * correct_1 / len(test_loader.dataset)))
  print('\nttt y_hat_2: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss_2, correct_2, len(test_loader.dataset),
    100. * correct_2 / len(test_loader.dataset)))


In [7]:
# def ttt_one_instance(x, f_ttt, f_copy, optimizer, n_steps, n_classes=10):
#   for step in range(n_steps):
#     f_copy.load_state_dict(f_ttt.state_dict())
#     y_hat_0 = torch.ones((len(x), 10)).to(device) / n_classes
#     y_hat_1, x_mid = f_ttt(x, y_hat_0, return_mid=True)
#     y_hat_2 = f_copy.f2(x_mid, y_hat_1)
#     y_hat_2_adv = f_ttt.f2(x_mid.detach(), y_hat_1.detach())
    
#     # loss_const = js_divergence(y_hat_1, y_hat_2)
#     # loss_adv = -js_divergence(y_hat_1.detach(), y_hat_2_adv)*100
#     loss_adv = -(y_hat_1.detach() - y_hat_2).pow(2).mean()*100
#     loss_const = (y_hat_1 - y_hat_2).pow(2).mean()
#     # args = y_hat_1.argmax(dim=1)
#     # loss_adv = -(y_hat_1[args].detach() - y_hat_2_adv[args]).pow(2).mean()
#     loss_entropy = -(y_hat_1 * torch.log(y_hat_1 + 1e-10)).sum(-1).mean()*0.01
#     loss = loss_const + loss_adv + loss_entropy

#     print(y_hat_1[0].argmax().item(), y_hat_2[0].argmax().item(), y_hat_2_adv[0].argmax().item())
    
#     # loss = (y_hat_1 - y_hat_2).pow(2).mean()
    
#     print(f'step {step}: loss={loss.item()}')
#     optimizer.zero_grad()
#     loss.backward()
#     if f_ttt.f2.fc2.weight.grad.var() == 0:
#       print('zero grad')
#     optimizer.step()
#   return y_hat_1, y_hat_2


# def ttt(f, test_loader, n_steps, lr):
#   f_ttt = deepcopy(f)
#   f_copy = deepcopy(f)
#   f.eval()
#   f_copy.train()
#   f_ttt.train()
#   optimizer = optim.Adam(f_ttt.parameters(), lr=lr)
#   test_loss_1, correct_1, test_losses_1 = 0, 0, []
#   test_loss_2, correct_2, test_losses_2 = 0, 0, []

#   for data, target in test_loader:
#     x, y = data.to(device), target.to(device)
#     f_ttt.load_state_dict(f.state_dict())  # reset f_ttt to f
#     f_ttt.train()

#     y_hat_1, y_hat_2 = ttt_one_instance(x, f_ttt, f_copy, optimizer, n_steps)

#     test_loss_1 += F.nll_loss(y_hat_1.log(), y, size_average=False).item()
#     test_loss_2 += F.nll_loss(y_hat_2.log(), y, size_average=False).item()

#     pred_1 = y_hat_1.data.max(1, keepdim=True)[1]
#     pred_2 = y_hat_2.data.max(1, keepdim=True)[1]

#     correct_1 += pred_1.eq(y.data.view_as(pred_1)).sum()
#     correct_2 += pred_2.eq(y.data.view_as(pred_2)).sum()

#   test_loss_1 /= len(test_loader.dataset)
#   test_loss_2 /= len(test_loader.dataset)

#   test_losses_1.append(test_loss_1)
#   test_losses_2.append(test_loss_2)

#   print('\nttt y_hat_1: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
#     test_loss_1, correct_1, len(test_loader.dataset),
#     100. * correct_1 / len(test_loader.dataset)))
#   print('\nttt y_hat_2: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
#     test_loss_2, correct_2, len(test_loader.dataset),
#     100. * correct_2 / len(test_loader.dataset)))

In [23]:
class AddGaussianNoise(object):
    def __init__(self, std, mean=0.):
        self.std = std
        self.mean = mean
        
    def __call__(self, tensor):
        return tensor + torch.randn(tensor.size()) * self.std + self.mean
    
    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)


transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
    AddGaussianNoise(1.75)])
ood_dataset = datasets.MNIST('data', train=False, download=True, transform=transform)
ood_loader = torch.utils.data.DataLoader(ood_dataset, batch_size=64, shuffle=False)

test(f, ood_loader)
ttt(f, ood_loader, n_steps=12, lr=1e-3)

Max entropy given

Test: Avg. loss: 0.7275, Accuracy: 7525/10000 (75%)

batch 0/157:
batch 1/157:
batch 2/157:
batch 3/157:
batch 4/157:
batch 5/157:
step 0: loss=0.012889975681900978
8 8 8 9
step 11: loss=0.004932774696499109
7 7 7 9
batch 6/157:
batch 7/157:
step 0: loss=0.012837264686822891
9 8 8 9
step 11: loss=0.004485766869038343
7 3 3 9
batch 8/157:
batch 9/157:
batch 10/157:
batch 11/157:
batch 12/157:
batch 13/157:
batch 14/157:
batch 15/157:
step 0: loss=0.01212082989513874
5 5 5 7
batch 16/157:
step 0: loss=0.012450642883777618
6 8 8 4
batch 17/157:
batch 18/157:
step 0: loss=0.01734895259141922
2 8 8 9
batch 19/157:
step 0: loss=0.012396005913615227
1 1 1 7
step 11: loss=0.0034492812119424343
0 1 1 7
batch 20/157:
step 0: loss=0.010795621201395988
7 7 7 1
step 11: loss=0.0032694111578166485
3 2 3 1
batch 21/157:
batch 22/157:
batch 23/157:
step 11: loss=0.004042170941829681
9 5 5 3
batch 24/157:
step 0: loss=0.010347122326493263
7 8 8 6
batch 25/157:
step 0: loss=0.01027323

In [9]:
test(network, test_loader)

NameError: name 'network' is not defined