# DML 코드 practice

### dataset : CIFAR100
### Teacher model : MobileNet
### Student model : ResNet32 
### Ditillation method : DML (Deep mutual Learning)


# step 1 : 필요한 라이브러리 로드

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm  # Ensure you are importing the function from the module


# Step 2 : CIFAR100 데이터셋 로드

In [2]:
# Data augmentation and normalization for training set
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
])

# Only normalization for test set
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
])

# Download CIFAR-100 dataset
train_dataset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
test_dataset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)

# DataLoader
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=8)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False, num_workers=8)


Files already downloaded and verified
Files already downloaded and verified


# step 3 : Model definition

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# For reproducibility
torch.manual_seed(42)

<torch._C.Generator at 0x74b01c20a9d0>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# BasicBlock 클래스 정의 (ResNet의 기본 구성 요소)
class BasicBlock(nn.Module):
    expansion = 1  # 확장 계수 (필터의 크기를 변경하는 데 사용)

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        # 첫 번째 3x3 컨볼루션 레이어: 입력 채널(in_planes) -> 출력 채널(planes), stride와 padding 설정
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)  # 첫 번째 Batch Normalization 레이어
        # 두 번째 3x3 컨볼루션 레이어: 출력 채널 유지
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)  # 두 번째 Batch Normalization 레이어

        # Shortcut 경로: 입력과 출력의 크기가 다르면 크기를 조정하는 레이어 추가
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:  # 크기 또는 채널 수가 다를 경우
            self.shortcut = nn.Sequential(
                # 1x1 컨볼루션: 입력 채널(in_planes) -> 출력 채널(planes * expansion), stride 적용
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion * planes)  # Batch Normalization
            )

    def forward(self, x):
        # 입력 데이터에 대해 첫 번째 컨볼루션과 ReLU 활성화 함수 적용
        out = F.relu(self.bn1(self.conv1(x)))
        # 두 번째 컨볼루션 및 Batch Normalization 적용
        out = self.bn2(self.conv2(out))
        # Shortcut 경로를 통해 입력과 현재 출력 더하기
        out += self.shortcut(x)
        # 다시 ReLU 활성화 함수 적용
        out = F.relu(out)
        return out


# ResNet32 모델 정의
class ResNet32(nn.Module):
    def __init__(self, num_classes=100):
        super(ResNet32, self).__init__()
        self.in_planes = 16  # 입력 채널 크기 초기화

        # 첫 번째 컨볼루션 레이어와 Batch Normalization
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)  # 3채널(RGB) -> 16채널로 변환
        self.bn1 = nn.BatchNorm2d(16)

        # 세 개의 레이어 블록 생성 (각 블록은 여러 BasicBlock으로 구성)
        # layer1: 채널 16 유지, 5개의 블록, stride=1
        self.layer1 = self._make_layer(BasicBlock, 16, 5, stride=1)
        # layer2: 채널 32로 증가, 5개의 블록, stride=2 (공간 크기 절반으로 감소)
        self.layer2 = self._make_layer(BasicBlock, 32, 5, stride=2)
        # layer3: 채널 64로 증가, 5개의 블록, stride=2
        self.layer3 = self._make_layer(BasicBlock, 64, 5, stride=2)

        # Adaptive Average Pooling (출력 크기를 1x1로 조정)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Fully Connected 레이어 (64채널 -> 클래스 수(num_classes) 출력)
        self.fc = nn.Linear(64 * BasicBlock.expansion, num_classes)

    # 블록 생성 함수: 특정 채널 크기와 블록 수로 BasicBlock을 연결
    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)  # 첫 블록만 stride 적용
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))  # BasicBlock 추가
            self.in_planes = planes * block.expansion  # 다음 블록의 입력 크기 업데이트
        return nn.Sequential(*layers)

    def forward(self, x):
        # 첫 번째 컨볼루션과 BatchNorm
        out = F.relu(self.bn1(self.conv1(x)))
        # 세 개의 레이어 블록 통과
        out = self.layer1(out)  # 첫 번째 레이어 블록
        out = self.layer2(out)  # 두 번째 레이어 블록
        out = self.layer3(out)  # 세 번째 레이어 블록
        # Adaptive Average Pooling 적용
        out = self.avgpool(out)
        # 1D 벡터로 변환
        out = out.view(out.size(0), -1)
        # Fully Connected 레이어로 출력
        out = self.fc(out)
        return out


# 모델 초기화
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # GPU 사용 여부 확인

# MobileNetV2 모델 로드 (PyTorch의 미리 학습된 모델 사용)
mobilenet_model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', num_classes=100).to(device)

# ResNet32 모델 인스턴스 생성
resnet32_model = ResNet32(num_classes=100).to(device)

# 모델 정보 출력
print(mobilenet_model)  # MobileNetV2 구조 출력
print(resnet32_model)   # ResNet32 구조 출력


MobileNetV2(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU6(inplace=True)
    )
    (1): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
          (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU6(inplace=True)
        )
        (1): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): InvertedResidual(
      (conv): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (1): BatchNorm2d(96, eps=

Using cache found in /home/park/.cache/torch/hub/pytorch_vision_v0.10.0


# step 4 : loss function & optimizer 정의 

In [5]:
# Cross-entropy loss
criterion = nn.CrossEntropyLoss()

# Optimizers for both models
optimizer_mobile = optim.Adam(mobilenet_model.parameters(), lr=0.001, betas=(0.9, 0.999), weight_decay=5e-4)
optimizer_resnet = optim.Adam(resnet32_model.parameters(), lr=0.001, betas=(0.9, 0.999), weight_decay=5e-4)

# Learning rate scheduler
scheduler_mobile = optim.lr_scheduler.StepLR(optimizer_mobile, step_size=30, gamma=0.01)
scheduler_resnet = optim.lr_scheduler.StepLR(optimizer_resnet, step_size=30, gamma=0.01)


# step 5 : DML & Loss function 

In [6]:
def mutual_learning_loss(output1, output2):
    kl_loss = nn.KLDivLoss(reduction='batchmean')
    return kl_loss(nn.functional.log_softmax(output1, dim=1), nn.functional.softmax(output2, dim=1)) + \
           kl_loss(nn.functional.log_softmax(output2, dim=1), nn.functional.softmax(output1, dim=1))


# step 6 : ResNet32 before distillation

In [7]:
def train_resnet32(epoch, model, optimizer):
    model.train()
    
    total_loss = 0
    correct = 0
    total = 0
    
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f'Epoch {epoch}')
    
    for batch_idx, (inputs, targets) in progress_bar:
        inputs, targets = inputs.to(device), targets.to(device)
        
        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Statistics
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        total_loss += loss.item()
        
        progress_bar.set_postfix({
            'Loss': f'{total_loss/(batch_idx+1):.4f}',
            'Accuracy': f'{correct/total*100:.2f}%'
        })
# Evaluation function for ResNet32
def test_resnet32(epoch, model):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    criterion = nn.CrossEntropyLoss()
    progress_bar = tqdm(enumerate(test_loader), total=len(test_loader), desc=f"Epoch {epoch} [Testing ResNet32]")

    with torch.no_grad():
        for batch_idx, (inputs, targets) in progress_bar:
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            progress_bar.set_postfix({
                'Test Loss': f'{test_loss/(batch_idx+1):.4f}',
                'Accuracy': f'{correct/total*100:.2f}%'
            })
    
    return correct, total


# step 7 : training MobileNet and ResNet32 with DML

In [8]:

def train(epoch, model1, model2, optimizer1, optimizer2):
    model1.train()
    model2.train()
    
    total_loss1, total_loss2 = 0, 0
    correct1, correct2 = 0, 0
    total = 0
    
    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f'Epoch {epoch}')
    
    for batch_idx, (inputs, targets) in progress_bar:
        inputs, targets = inputs.to(device), targets.to(device)
        
        # Forward pass
        outputs1 = model1(inputs)
        outputs2 = model2(inputs)
        
        # Loss calculation
        loss1 = criterion(outputs1, targets)
        loss2 = criterion(outputs2, targets)
        ml_loss = mutual_learning_loss(outputs1, outputs2)
        
        total_loss = loss1 + loss2 + ml_loss
        
        # Backpropagation
        optimizer1.zero_grad()
        optimizer2.zero_grad()
        total_loss.backward()
        
        # Optimizer step
        optimizer1.step()
        optimizer2.step()
        
        # Statistics
        _, predicted1 = outputs1.max(1)
        _, predicted2 = outputs2.max(1)
        
        total += targets.size(0)
        correct1 += predicted1.eq(targets).sum().item()
        correct2 += predicted2.eq(targets).sum().item()
        
        total_loss1 += loss1.item()
        total_loss2 += loss2.item()
        
        # Update progress bar with current statistics
        progress_bar.set_postfix({
            'MobileNet Loss': f'{total_loss1/(batch_idx+1):.4f}',
            'ResNet32 Loss': f'{total_loss2/(batch_idx+1):.4f}',
            'MobileNet Acc': f'{correct1/total*100:.2f}%',
            'ResNet32 Acc': f'{correct2/total*100:.2f}%'
        })


# step 7 : testing loop

In [9]:
def test(epoch, model):
    model.eval()
    correct = 0
    total = 0
    test_loss = 0
    
    progress_bar = tqdm(enumerate(test_loader), total=len(test_loader), desc=f'Test Epoch {epoch}')
    
    with torch.no_grad():
        for batch_idx, (inputs, targets) in progress_bar:
            inputs, targets = inputs.to(device), targets.to(device)
            
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            
            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            
            # Update progress bar with current statistics
            progress_bar.set_postfix({
                'Test Loss': f'{test_loss/(batch_idx+1):.4f}',
                'Accuracy': f'{correct/total*100:.2f}%'
            })


In [10]:
# Step 1: Train ResNet32 Independently (Before Distillation)
resnet32_model_independent = ResNet32(num_classes=100).to(device)
optimizer_resnet_independent = optim.SGD(resnet32_model_independent.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)
scheduler_resnet_independent = optim.lr_scheduler.StepLR(optimizer_resnet_independent, step_size=30, gamma=0.1)

print("Training ResNet32 before distillation...")
for epoch in range(150):
    train_resnet32(epoch, resnet32_model_independent, optimizer_resnet_independent)
    correct, total = test_resnet32(epoch, resnet32_model_independent)
    scheduler_resnet_independent.step()

# Store the final accuracy of the independent ResNet32
resnet32_independent_accuracy = correct / total * 100
# Step 2: Train MobileNet and ResNet32 with DML (After Distillation)
print("Training MobileNet and ResNet32 with DML...")
for epoch in range(150):
    train(epoch, mobilenet_model, resnet32_model, optimizer_mobile, optimizer_resnet)
    correct, total = test_resnet32(epoch, resnet32_model)  # Evaluate ResNet32 after DML
    scheduler_mobile.step()
    scheduler_resnet.step()

# Store the final accuracy of the ResNet32 after DML
resnet32_dml_accuracy = correct / total * 100

# Print comparison of the independent ResNet32 and the DML-trained ResNet32
print(f'Accuracy Comparison:\nResNet32 Independent: {resnet32_independent_accuracy:.2f}%\nResNet32 after DML: {resnet32_dml_accuracy:.2f}%')


Training ResNet32 before distillation...


Epoch 0: 100%|██████████| 391/391 [00:03<00:00, 102.48it/s, Loss=4.4115, Accuracy=4.15%]
Epoch 0 [Testing ResNet32]: 100%|██████████| 100/100 [00:00<00:00, 276.55it/s, Test Loss=4.1783, Accuracy=6.20%]
Epoch 1: 100%|██████████| 391/391 [00:03<00:00, 108.48it/s, Loss=4.0208, Accuracy=8.76%]
Epoch 1 [Testing ResNet32]: 100%|██████████| 100/100 [00:00<00:00, 334.44it/s, Test Loss=3.8520, Accuracy=11.30%]
Epoch 2: 100%|██████████| 391/391 [00:03<00:00, 105.41it/s, Loss=3.7660, Accuracy=12.03%]
Epoch 2 [Testing ResNet32]: 100%|██████████| 100/100 [00:00<00:00, 292.17it/s, Test Loss=3.6618, Accuracy=14.12%]
Epoch 3: 100%|██████████| 391/391 [00:03<00:00, 109.23it/s, Loss=3.5948, Accuracy=14.72%]
Epoch 3 [Testing ResNet32]: 100%|██████████| 100/100 [00:00<00:00, 321.20it/s, Test Loss=3.5048, Accuracy=16.29%]
Epoch 4: 100%|██████████| 391/391 [00:03<00:00, 106.47it/s, Loss=3.4512, Accuracy=16.94%]
Epoch 4 [Testing ResNet32]: 100%|██████████| 100/100 [00:00<00:00, 307.51it/s, Test Loss=3.4178, 

Training MobileNet and ResNet32 with DML...



Epoch 0: 100%|██████████| 391/391 [00:06<00:00, 61.78it/s, MobileNet Loss=4.2244, ResNet32 Loss=4.1053, MobileNet Acc=4.76%, ResNet32 Acc=6.76%]
Epoch 0 [Testing ResNet32]: 100%|██████████| 100/100 [00:00<00:00, 307.31it/s, Test Loss=3.7796, Accuracy=11.09%]
Epoch 1: 100%|██████████| 391/391 [00:06<00:00, 60.95it/s, MobileNet Loss=3.8250, ResNet32 Loss=3.6699, MobileNet Acc=9.94%, ResNet32 Acc=12.96%]
Epoch 1 [Testing ResNet32]: 100%|██████████| 100/100 [00:00<00:00, 303.58it/s, Test Loss=3.5173, Accuracy=15.44%]
Epoch 2: 100%|██████████| 391/391 [00:06<00:00, 62.15it/s, MobileNet Loss=3.6291, ResNet32 Loss=3.4068, MobileNet Acc=12.94%, ResNet32 Acc=17.96%]
Epoch 2 [Testing ResNet32]: 100%|██████████| 100/100 [00:00<00:00, 279.71it/s, Test Loss=3.2261, Accuracy=21.14%]
Epoch 3: 100%|██████████| 391/391 [00:06<00:00, 61.57it/s, MobileNet Loss=3.4549, ResNet32 Loss=3.1985, MobileNet Acc=16.05%, ResNet32 Acc=22.03%]
Epoch 3 [Testing ResNet32]: 100%|██████████| 100/100 [00:00<00:00, 267.6

Accuracy Comparison:
ResNet32 Independent: 52.73%
ResNet32 after DML: 56.77%



