![image.png](attachment:image.png)

### classification loss(CE) + consistency loss(mse between student & teacher outputs)로 student model backpropagation
### 이 후 student weight 의 EMA를 teacher에 덮어씌움으로 teacher weight update
### 이 코드는 중간에 들어가는 noise는 딱히 없어보임.. 
### temporal ensemble은 1에폭마다 z 정보 갱신.. (iter별로는 z의 해당부분 loss구하는데 사용 & outputs에 하나씩 쌓음)
### mean teacher는 매 iter마다 teacher의 weight 파라미터 갱신 -> 더 빠른 학습 가능

In [1]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [2]:
root = '~/datasets/MNIST'

train_dataset = datasets.MNIST(
        root=root,
        train=True,
        transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                           ]),
        download=True)

# load test data
test_dataset = datasets.MNIST(
    root=root,
    train=False,
    transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                       ]),
    download=True)

In [3]:
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=64,
                                               num_workers=2,
                                               shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=64,
                                              num_workers=2,
                                              shuffle=False)

In [4]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

In [5]:
def train(args, model, mean_teacher, device, train_loader, test_loader, optimizer, epoch):
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        output = model(data)

        # forward pass with mean teacher
        # torch.no_grad() prevents gradients from being passed into mean teacher model
        with torch.no_grad():
            mean_t_output = mean_teacher(data)

        # consistency loss (example with MSE, you can change)
        const_loss = F.mse_loss(output, mean_t_output)

        # set the consistency weight (should schedule)
        weight = 0.2
        loss = F.nll_loss(output, target) + weight*const_loss
        loss.backward()
        optimizer.step()

        ########################### CODE CHANGE HERE ######################################
        # update mean teacher, (should choose alpha somehow)
        # Use the true average until the exponential average is more correct
        alpha = 0.95
        for mean_param, param in zip(mean_teacher.parameters(), model.parameters()):
            mean_param.data.mul_(alpha).add_(1 - alpha, param.data) # teacher_param = alpha * teacher_param + (1-alpha) * student_param

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))
            test(args, model, device, test_loader)
            test(args, mean_teacher, device, test_loader)


def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    model.train()

In [6]:
class arg():
    def __init__(self):
        self.batch_size = 64
        self.test_batch_size = 1000
        self.epochs = 10
        self.lr = 0.01
        self.momentum = 0.5
        self.no_cuda = False
        self.seed = 1
        self.log_interval = 100
        self.save_model = False

In [7]:
def main():
    # Training settings
    args = arg()
    

    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}


    model = Net().to(device)
    ########################### CODE CHANGE HERE ######################################
    # initialize mean teacher
    mean_teacher = Net().to(device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

    for epoch in range(1, args.epochs + 1):
        train(args, model, mean_teacher, device, train_loader, test_loader, optimizer, epoch)

    if (args.save_model):
        torch.save(model.state_dict(), "mnist_cnn.pt")


main()

	add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
	add_(Tensor other, *, Number alpha) (Triggered internally at  ..\torch\csrc\utils\python_arg_parser.cpp:882.)



Test set: Average loss: 2.2988, Accuracy: 1165/10000 (12%)


Test set: Average loss: 2.3035, Accuracy: 992/10000 (10%)


Test set: Average loss: 0.7570, Accuracy: 8546/10000 (85%)


Test set: Average loss: 0.9504, Accuracy: 8443/10000 (84%)


Test set: Average loss: 0.4120, Accuracy: 9118/10000 (91%)


Test set: Average loss: 0.4448, Accuracy: 9121/10000 (91%)


Test set: Average loss: 0.2878, Accuracy: 9375/10000 (94%)


Test set: Average loss: 0.3025, Accuracy: 9364/10000 (94%)


Test set: Average loss: 0.2318, Accuracy: 9501/10000 (95%)


Test set: Average loss: 0.2318, Accuracy: 9512/10000 (95%)


Test set: Average loss: 0.1878, Accuracy: 9572/10000 (96%)


Test set: Average loss: 0.1909, Accuracy: 9586/10000 (96%)


Test set: Average loss: 0.1609, Accuracy: 9642/10000 (96%)


Test set: Average loss: 0.1622, Accuracy: 9639/10000 (96%)


Test set: Average loss: 0.1429, Accuracy: 9671/10000 (97%)


Test set: Average loss: 0.1422, Accuracy: 9687/10000 (97%)


Test set: Average loss: 


Test set: Average loss: 0.0369, Accuracy: 9900/10000 (99%)


Test set: Average loss: 0.0361, Accuracy: 9903/10000 (99%)


Test set: Average loss: 0.0356, Accuracy: 9902/10000 (99%)


Test set: Average loss: 0.0355, Accuracy: 9902/10000 (99%)


Test set: Average loss: 0.0357, Accuracy: 9906/10000 (99%)


Test set: Average loss: 0.0349, Accuracy: 9904/10000 (99%)


Test set: Average loss: 0.0354, Accuracy: 9896/10000 (99%)


Test set: Average loss: 0.0351, Accuracy: 9897/10000 (99%)


Test set: Average loss: 0.0354, Accuracy: 9901/10000 (99%)


Test set: Average loss: 0.0351, Accuracy: 9898/10000 (99%)


Test set: Average loss: 0.0356, Accuracy: 9895/10000 (99%)


Test set: Average loss: 0.0344, Accuracy: 9901/10000 (99%)


Test set: Average loss: 0.0336, Accuracy: 9902/10000 (99%)


Test set: Average loss: 0.0338, Accuracy: 9903/10000 (99%)


Test set: Average loss: 0.0334, Accuracy: 9904/10000 (99%)


Test set: Average loss: 0.0333, Accuracy: 9904/10000 (99%)


Test set: Average loss:


Test set: Average loss: 0.0265, Accuracy: 9920/10000 (99%)


Test set: Average loss: 0.0263, Accuracy: 9921/10000 (99%)


Test set: Average loss: 0.0258, Accuracy: 9923/10000 (99%)


Test set: Average loss: 0.0258, Accuracy: 9924/10000 (99%)


Test set: Average loss: 0.0254, Accuracy: 9924/10000 (99%)


Test set: Average loss: 0.0255, Accuracy: 9924/10000 (99%)


Test set: Average loss: 0.0257, Accuracy: 9923/10000 (99%)


Test set: Average loss: 0.0253, Accuracy: 9925/10000 (99%)

