In [1]:
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler


from torchvision import transforms, datasets

from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
import time
import copy

In [2]:
model_base = timm.create_model('resnet18', pretrained=True, num_classes=2)
model_alpha = timm.create_model('resnet18', pretrained=True, num_classes=2)
model_beta = timm.create_model('resnet18', pretrained=True, num_classes=2)

In [3]:
config = resolve_data_config({}, model=model_base)
transform = create_transform(**config)



In [4]:
alpha_data_dir = './datasets/alpha/train/'
beta_data_dir = './datasets/beta/train/'
val_data_dir = './datasets/test/'
alpha_datasets = {
    "train": datasets.ImageFolder(alpha_data_dir, transform),
    "val": datasets.ImageFolder(val_data_dir, transform)
}
beta_datasets = {
    "train": datasets.ImageFolder(beta_data_dir, transform),
    "val": datasets.ImageFolder(val_data_dir, transform)
}
alpha_data_sizes = {x: len(alpha_datasets[x]) for x in ["train", "val"]}
beta_data_sizes = {x: len(beta_datasets[x]) for x in ["train", "val"]}

class_names = alpha_datasets["train"].classes
alpha_data_loader = {x: torch.utils.data.DataLoader(alpha_datasets[x], shuffle=True, batch_size=4, num_workers=0) for x in ['train', 'val']}
beta_data_loader = {x: torch.utils.data.DataLoader(beta_datasets[x], shuffle=True, batch_size=4, num_workers=0) for x in ['train', 'val']}

In [5]:
def train(model, criterion, optimizer, scheduler, num_epochs, data_loader, data_size, evaluate=False):
    since = time.time()
    
    best_model = copy.deepcopy(model.state_dict())
    best_acc = 0
    
    for epoch in range(num_epochs):
        print(f"Epoch : {epoch+1} / {num_epochs}")
        print("-"*15)
        
        for phase in ["train", "val"]:
            if phase == 'train':
                # evaluate가 Ture면 Train은 건너뜀~
                if evaluate:
                    continue
                model.train()
            else:
                model.eval()
            
            running_loss = 0.0
            running_corrects = 0
            
            for inputs, labels in data_loader[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                optimizer.zero_grad()
                
                with torch.set_grad_enabled(phase=='train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                scheduler.step()
                
            epoch_loss = running_loss / data_size[phase]
            epoch_acc = running_corrects / data_size[phase]
                
            print("{} Loss : {:.4f} Acc : {:.4f}".format(phase, epoch_loss, epoch_acc))

            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model = copy.deepcopy(model.state_dict())
    time_elapsed = time.time() - since
    print("Train Complete in {:.0f}m {:.4f}s".format(time_elapsed // 60, time_elapsed % 60))
    print("Best Acc : {:.4f}".format(best_acc))
    
    model.load_state_dict(best_model)
    return model

In [6]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model_alpha = model_alpha.to(device)

criterion_alpha = nn.CrossEntropyLoss()
optimizer_alpha = optim.SGD(model_alpha.parameters(), lr=0.001, momentum=0.9)
scheduler_alpha = lr_scheduler.StepLR(optimizer_alpha, step_size=7, gamma=0.1)

model_beta = model_alpha.to(device)

criterion_beta = nn.CrossEntropyLoss()
optimizer_beta = optim.SGD(model_beta.parameters(), lr=0.001, momentum=0.9)
scheduler_beta = lr_scheduler.StepLR(optimizer_beta, step_size=7, gamma=0.1)

num_epochs=3

In [7]:
# Alpha 모델 학습
model_alpha = train(model_alpha, criterion_alpha, optimizer_alpha, scheduler_alpha, num_epochs, alpha_data_loader, alpha_data_sizes)

Epoch : 1 / 3
---------------
train Loss : 0.3581 Acc : 0.8750
val Loss : 0.3463 Acc : 0.8667
Epoch : 2 / 3
---------------
train Loss : 0.1367 Acc : 0.9625
val Loss : 0.3294 Acc : 0.8815
Epoch : 3 / 3
---------------
train Loss : 0.0800 Acc : 0.9750
val Loss : 0.1713 Acc : 0.9111
Train Complete in 2m 9.0517s
Best Acc : 0.9111


In [8]:
# Beta 모델 학습
model_beta = train(model_beta, criterion_beta, optimizer_beta, scheduler_beta, num_epochs, beta_data_loader,beta_data_sizes)

Epoch : 1 / 3
---------------
train Loss : 0.6713 Acc : 0.8125
val Loss : 0.7481 Acc : 0.6741
Epoch : 2 / 3
---------------
train Loss : 0.1652 Acc : 0.9250
val Loss : 0.0919 Acc : 0.9630
Epoch : 3 / 3
---------------
train Loss : 0.0100 Acc : 1.0000
val Loss : 0.3145 Acc : 0.8815
Train Complete in 2m 8.9272s
Best Acc : 0.9630


In [9]:
model_base = model_base.to(device)

criterion_base = nn.CrossEntropyLoss()
optimizer_base = optim.SGD(model_base.parameters(), lr=0.001, momentum=0.9)
scheduler_base = lr_scheduler.StepLR(optimizer_base, step_size=7, gamma=0.1)

In [10]:
# origin model
model_base = train(model_base, criterion_base, optimizer_base, scheduler_base, num_epochs, alpha_data_loader,alpha_data_sizes, evaluate=True)

Epoch : 1 / 3
---------------
val Loss : 0.7009 Acc : 0.5407
Epoch : 2 / 3
---------------
val Loss : 0.7009 Acc : 0.5407
Epoch : 3 / 3
---------------
val Loss : 0.7009 Acc : 0.5407
Train Complete in 0m 45.3903s
Best Acc : 0.5407


In [11]:
# Alpha와 Beta Weight 평균내기!
sd_alpha = model_alpha.state_dict()
sd_beta = model_beta.state_dict()
sd_new = dict()
for key in sd_alpha:
    sd_new[key] = (sd_beta[key] + sd_alpha[key]) / 2.

model_base.load_state_dict(sd_new)

<All keys matched successfully>

In [12]:
# new_model
model_base = train(model_base, criterion_base, optimizer_base, scheduler_base, num_epochs,alpha_data_loader, alpha_data_sizes, evaluate=True)

Epoch : 1 / 3
---------------
val Loss : 0.0919 Acc : 0.9630
Epoch : 2 / 3
---------------
val Loss : 0.0919 Acc : 0.9630
Epoch : 3 / 3
---------------
val Loss : 0.0919 Acc : 0.9630
Train Complete in 0m 45.5853s
Best Acc : 0.9630
