In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.optim as optim

In [2]:
torch.cuda.is_available() # gpu 사용 확인

True

In [3]:
use_cuda=torch.cuda.is_available()
device=torch.device("cuda" if use_cuda else "cpu")
device #cuda

device(type='cuda')

이미지 데이터에 대한 데이터 전처리

In [4]:
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding = 4),
    transforms.ToTensor(),
    transforms.Normalize((0.507, 0.487, 0.441),(0.267, 0.256, 0.276))
])

Dataset 불러오기

In [5]:
train_dataset = datasets.CIFAR10(root= '/data', train = True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root= '/data', train = False, download=True, transform=transform)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:03<00:00, 43707887.30it/s]


Extracting /data/cifar-10-python.tar.gz to /data
Files already downloaded and verified


Dataloader 생성


1.   batchsize = 128
2.   Trainset은 shuffle, Testset은 no-shuffle


In [6]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle = True, num_workers=2)
test_loader = torch.utils.data.DatbaLoader(test_dataset, batch_size=128, shuffle = False, num_workers=2)

하이퍼파라미터 설정

In [11]:
epochs = 150
best_acc = 0
learning_rate = 0.01
momentum = 0.9
weight_decay = 0.0001
criterion = nn.CrossEntropyLoss()

Teacher model 생성 및 학습

In [14]:
class Teacher(nn.Module):
  def __init__(self, num_classes=10):
    super(Teacher, self).__init__()
    self.features = nn.Sequential(
        nn.Conv2d(3,32, kernel_size=3, padding=1),
        nn.BatchNorm2d(32),
        nn.ReLU(),

        nn.Conv2d(32,32, kernel_size=3, padding=1),
        nn.BatchNorm2d(32),
        nn.ReLU(),

        nn.MaxPool2d(kernel_size=2, stride =2),

        nn.Conv2d(32,64, kernel_size=3, padding=1),
        nn.BatchNorm2d(64),
        nn.ReLU(),

        nn.Conv2d(64,64, kernel_size=3, padding=1),
        nn.BatchNorm2d(64),
        nn.ReLU(),

        nn.MaxPool2d(kernel_size=2, stride =2),


        nn.Conv2d(64,128,kernel_size=3, padding=1),
        nn.BatchNorm2d(128),
        nn.ReLU(),

        nn.Conv2d(128,128,kernel_size=3, padding=1),
        nn.BatchNorm2d(128),
        nn.ReLU(),

        nn.MaxPool2d(kernel_size=2, stride =2),


        nn.Conv2d(128,256,kernel_size=3, padding=1),
        nn.BatchNorm2d(256),
        nn.ReLU(),

        nn.Conv2d(256,256,kernel_size=3, padding=1),
        nn.BatchNorm2d(256),
        nn.ReLU(),

        nn.Conv2d(256,256,kernel_size=3, padding=1),
        nn.BatchNorm2d(256),
        nn.ReLU(),

        nn.Conv2d(256,256,kernel_size=3, padding=1),
        nn.BatchNorm2d(256),
        nn.ReLU(),

        nn.MaxPool2d(kernel_size=2, stride =2),
    )

    self.fc_layers = nn.Sequential(
        nn.Linear(1024,128),
        nn.Linear(128,10)
     )
  def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.fc_layers(x)
        return x

teacher = Teacher().to(device) #Teacher모델 gpu에 생성
optimizer = optim.SGD(teacher.parameters(),lr=learning_rate,momentum=momentum,weight_decay=weight_decay)

for epoch in range(epochs):
  teacher.train()
  train_loss = 0
  correct = 0
  total = 0

  for batch_idx, (data, target) in enumerate(train_loader):
    data = data.to(device)
    target = target.to(device)
    optimizer.zero_grad()
    output = teacher(data)
    loss = criterion(output,target)
    loss.backward()
    optimizer.step()
    train_loss += loss.item()
    _, predicted = torch.max(output.data, 1)
    total += target.size(0)
    correct += predicted.eq(target.data).cpu().sum()
    if batch_idx % 100 == 0:
      print('Epoch: {} | Batch_idx: {} |  Loss: ({:.4f}) | Acc: ({:.2f}%) ({}/{})'
      .format(epoch, batch_idx, train_loss / (batch_idx + 1), 100. * correct / total, correct, total))

print("============Training finished=============")


teacher.eval()  # 모델 평가모드
with torch.no_grad(): #no gradient
  correct = 0
  val_acc = 0
  total = 0
  for images, labels in test_loader:
    images = images.to(device)
    labels = labels.to(device)
    outputs = teacher(images)
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()

  val_acc = 100 * correct / total
  print('Accuracy on the test set: {}'.format(val_acc))

if val_acc > best_acc:
  best_acc = val_acc
  torch.save({
      'epoch': epoch,
      'model_state_dict': teacher.state_dict(),
      'optimizer_state_dict': optimizer.state_dict(),
      'loss': loss.item(),
      }, '/content/teacher_model.pth') # 모델 epochs, weight, opimizer 상태,loss 값 등 체크포인트 저장



Epoch: 0 | Batch_idx: 0 |  Loss: (2.4015) | Acc: (9.38%) (12/128)
Epoch: 0 | Batch_idx: 100 |  Loss: (1.7316) | Acc: (35.70%) (4615/12928)
Epoch: 0 | Batch_idx: 200 |  Loss: (1.5593) | Acc: (42.76%) (11002/25728)
Epoch: 0 | Batch_idx: 300 |  Loss: (1.4403) | Acc: (47.35%) (18242/38528)
Epoch: 1 | Batch_idx: 0 |  Loss: (0.9853) | Acc: (64.06%) (82/128)
Epoch: 1 | Batch_idx: 100 |  Loss: (0.9630) | Acc: (65.69%) (8493/12928)
Epoch: 1 | Batch_idx: 200 |  Loss: (0.9305) | Acc: (67.23%) (17296/25728)
Epoch: 1 | Batch_idx: 300 |  Loss: (0.9100) | Acc: (67.94%) (26176/38528)
Epoch: 2 | Batch_idx: 0 |  Loss: (0.6201) | Acc: (79.69%) (102/128)
Epoch: 2 | Batch_idx: 100 |  Loss: (0.7532) | Acc: (73.48%) (9499/12928)
Epoch: 2 | Batch_idx: 200 |  Loss: (0.7367) | Acc: (73.98%) (19034/25728)
Epoch: 2 | Batch_idx: 300 |  Loss: (0.7355) | Acc: (74.11%) (28554/38528)
Epoch: 3 | Batch_idx: 0 |  Loss: (0.6999) | Acc: (78.12%) (100/128)
Epoch: 3 | Batch_idx: 100 |  Loss: (0.6558) | Acc: (77.08%) (9965/12

Teacher 및 Student Model 생성

학습한 Teacher, student model, KD모델 경로 지정

모델 학습 함수

In [17]:
class Student(nn.Module):
   def __init__(self, num_classes=10):
      super(Student,self).__init__()
      self.features = nn.Sequential(
         nn.Conv2d(3,16,kernel_size=3, padding =1),
         nn.BatchNorm2d(16),
         nn.ReLU(),

         nn.MaxPool2d(kernel_size=2, stride=2),


         nn.Conv2d(16,16,kernel_size=3, padding =1),
         nn.BatchNorm2d(16),
         nn.ReLU(),

         nn.MaxPool2d(kernel_size=2, stride=2),

        )
      self.fc_layers = nn.Linear(1024,10)
   def forward(self, x):
      x = self.features(x)
      x = x.view(x.size(0), -1)
      x= self.fc_layers(x)
      return x


student = Student().to(device)  #Student모델 gpu에 생성

optimizer = optim.SGD(student.parameters(),lr=learning_rate,momentum=momentum,weight_decay=weight_decay)

for epoch in range(epochs):
  student.train()
  train_loss = 0
  correct = 0
  total = 0

  for batch_idx, (data, target) in enumerate(train_loader):
    data = data.to(device)
    target = target.to(device)
    optimizer.zero_grad()
    output = student(data)
    loss = criterion(output,target)
    loss.backward()
    optimizer.step()
    train_loss += loss.item()
    _, predicted = torch.max(output.data, 1)
    total += target.size(0)
    correct += predicted.eq(target.data).cpu().sum()
    if batch_idx % 100 == 0:
      print('Epoch: {} | Batch_idx: {} |  Loss: ({:.4f}) | Acc: ({:.2f}%) ({}/{})'
      .format(epoch, batch_idx, train_loss / (batch_idx + 1), 100. * correct / total, correct, total))

print("============Training finished=============")

student.eval()  # 모델 평가모드
with torch.no_grad(): #no gradient
  correct = 0
  val_acc = 0
  total = 0
  for images, labels in test_loader:
    images = images.to(device)
    labels = labels.to(device)
    outputs = student(images)
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()

  val_acc = 100 * correct / total
  print('Accuracy on the test set: {}'.format(val_acc))
if val_acc > best_acc:
  best_acc = val_acc
  torch.save({
      'epoch': epoch,
      'model_state_dict': student.state_dict(),
      'optimizer_state_dict': optimizer.state_dict(),
      'loss': loss.item(),
      }, '/content/student_model.pth') # 모델 epochs, weight, opimizer 상태,loss 값 등 체크포인트 저장

Epoch: 0 | Batch_idx: 0 |  Loss: (2.5023) | Acc: (12.50%) (16/128)
Epoch: 0 | Batch_idx: 100 |  Loss: (1.8115) | Acc: (34.43%) (4451/12928)
Epoch: 0 | Batch_idx: 200 |  Loss: (1.6843) | Acc: (38.93%) (10015/25728)
Epoch: 0 | Batch_idx: 300 |  Loss: (1.6027) | Acc: (41.99%) (16177/38528)
Epoch: 1 | Batch_idx: 0 |  Loss: (1.4834) | Acc: (48.44%) (62/128)
Epoch: 1 | Batch_idx: 100 |  Loss: (1.3235) | Acc: (53.02%) (6855/12928)
Epoch: 1 | Batch_idx: 200 |  Loss: (1.3072) | Acc: (53.51%) (13766/25728)
Epoch: 1 | Batch_idx: 300 |  Loss: (1.2970) | Acc: (54.04%) (20820/38528)
Epoch: 2 | Batch_idx: 0 |  Loss: (1.3878) | Acc: (46.09%) (59/128)
Epoch: 2 | Batch_idx: 100 |  Loss: (1.2439) | Acc: (56.08%) (7250/12928)
Epoch: 2 | Batch_idx: 200 |  Loss: (1.2245) | Acc: (56.35%) (14499/25728)
Epoch: 2 | Batch_idx: 300 |  Loss: (1.2223) | Acc: (56.80%) (21884/38528)
Epoch: 3 | Batch_idx: 0 |  Loss: (1.2215) | Acc: (57.81%) (74/128)
Epoch: 3 | Batch_idx: 100 |  Loss: (1.1630) | Acc: (59.41%) (7681/129

Knowledge Distillation Train





In [18]:
trained_teacher = Teacher().to(device)
model_ckp = torch.load('/content/teacher_model.pth')
trained_teacher.load_state_dict(model_ckp['model_state_dict'])
#teacher model 새로 생성후 teacher_checkpoint load를 가져와서 trained_teacher에 적용



#Knowledge distillation을 위한 parameters (lamda_, T)
lambda_ =0.0001
T = 4.5
kl_div_loss = nn.KLDivLoss() #Knowledge distillation을 위한 Cost function


for epoch in range(epochs):
    student.train() #위에 train된 student 모델이 아니라 위에 만든 student를 그대로 다시 적용
    train_loss = 0
    correct = 0
    total = 0

    for batch_idx, (data, target) in enumerate(train_loader):
      data = data.to(device)
      target = target.to(device)
      optimizer.zero_grad() #optimizer = optim.SGD(student.parameters(),lr=learning_rate,momentum=momentum,weight_decay=weight_decay)
      output = student(data)
      loss_SL = criterion(output,target) #Standard Learning loss
      teacher_outputs =trained_teacher(data)
      loss_KD = kl_div_loss(F.log_softmax(output / T, dim=1),
                            F.softmax(teacher_outputs / T, dim=1)) # loss_KD =KLDivLoss(log_softmax(output/T),softmax(teacher_outputs/T))
      loss = (1 - lambda_) * loss_SL + lambda_ * T * T * loss_KD  # total_loss = (1 −λ)⋅loss_SL +λ⋅T^2 ⋅loss_KD)
      loss.backward()
      optimizer.step()
      train_loss += loss.item()
      _, predicted = torch.max(output.data, 1)
      total += target.size(0)
      correct += predicted.eq(target.data).cpu().sum()
      if batch_idx % 100 == 0:
          print('Epoch: {} | Batch_idx: {} |  Loss: ({:.4f}) | Acc: ({:.2f}%) ({}/{})'
          .format(epoch, batch_idx, train_loss / (batch_idx + 1), 100. * correct / total, correct, total))
print("============Training finished=============")


student.eval()  # 모델 평가모드
with torch.no_grad(): #no gradient
     correct = 0
     val_acc = 0
     total = 0
     for images, labels in test_loader:
         images = images.to(device)
         labels = labels.to(device)
         outputs = student(images)
         _, predicted = torch.max(outputs.data, 1)
         total += labels.size(0)
         correct += (predicted == labels).sum().item()

         val_acc = 100 * correct / total
         print('Accuracy on the test set: {}'.format(val_acc))
if val_acc > best_acc:
  best_acc = val_acc
  torch.save({
      'epoch': epoch,
      'model_state_dict': student.state_dict(),
      'optimizer_state_dict': optimizer.state_dict(),
      'loss': loss.item(),
      }, '/content/KD_model.pth') # KD 적용한 student 모델의 epochs, weight, opimizer 상태,loss 값 등 체크포인트 저장




Epoch: 0 | Batch_idx: 0 |  Loss: (0.9446) | Acc: (66.41%) (85/128)
Epoch: 0 | Batch_idx: 100 |  Loss: (0.8893) | Acc: (68.84%) (8899/12928)
Epoch: 0 | Batch_idx: 200 |  Loss: (0.8836) | Acc: (68.86%) (17716/25728)
Epoch: 0 | Batch_idx: 300 |  Loss: (0.8822) | Acc: (69.15%) (26641/38528)
Epoch: 1 | Batch_idx: 0 |  Loss: (0.9205) | Acc: (65.62%) (84/128)
Epoch: 1 | Batch_idx: 100 |  Loss: (0.8744) | Acc: (69.48%) (8982/12928)
Epoch: 1 | Batch_idx: 200 |  Loss: (0.8770) | Acc: (69.38%) (17850/25728)
Epoch: 1 | Batch_idx: 300 |  Loss: (0.8818) | Acc: (69.18%) (26652/38528)
Epoch: 2 | Batch_idx: 0 |  Loss: (0.7738) | Acc: (75.78%) (97/128)
Epoch: 2 | Batch_idx: 100 |  Loss: (0.8915) | Acc: (68.84%) (8899/12928)
Epoch: 2 | Batch_idx: 200 |  Loss: (0.8848) | Acc: (69.01%) (17755/25728)
Epoch: 2 | Batch_idx: 300 |  Loss: (0.8895) | Acc: (68.85%) (26528/38528)
Epoch: 3 | Batch_idx: 0 |  Loss: (0.9452) | Acc: (65.62%) (84/128)
Epoch: 3 | Batch_idx: 100 |  Loss: (0.8719) | Acc: (69.81%) (9025/129