In [None]:
import warnings
from pathlib import Path

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from tqdm import tqdm

from DataLoader import AOIDataset

# Silence every Python warning for the rest of the kernel session.
# NOTE(review): this also hides PyTorch/torchvision deprecation and user warnings
# that can point at real bugs; consider narrowing the filter by category or module.
warnings.filterwarnings(action='ignore')

### wandb 초기화 

In [None]:
# Start the Weights & Biases run that the training loop below logs to.
run = wandb.init(
    name='aoi',
    project='resnet18_evaluation',
)

### 하이퍼파라미터 선언

In [None]:
# Hyperparameters and runtime configuration.
epochs = 50
lr = 1e-3
batch_size = 64
seed = 42

# Seed torch's RNGs (CPU and all CUDA devices) so classifier-head initialisation
# and DataLoader shuffling are repeatable across runs.
torch.manual_seed(seed)

# Prefer the second GPU (the original choice) when it exists, but fall back to the
# first GPU or the CPU instead of failing on machines with fewer devices.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
print(device)

In [None]:
# Sanity check for the device choice: is CUDA usable, and how many GPUs are visible?
# Device index 1 ('cuda:1') is only valid when at least two GPUs are present.
torch.cuda.is_available(), torch.cuda.device_count()

In [None]:
# Record the run's hyperparameters on the wandb config in a single update.
wandb.config.update({
    'epochs': epochs,
    'lr': lr,
    'batch_size': batch_size,
})

### 데이터 정의

In [None]:
# Data: AOI images converted to tensors and resized to 512x512.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((512, 512)),
])

train_data = AOIDataset(train=True, transform=transform)
val_data = AOIDataset(val=True, transform=transform)
test_data = AOIDataset(test=True, transform=transform)

# Use the configured batch_size hyperparameter rather than a hard-coded value.
# Only the training loader is shuffled: evaluation order does not affect the
# loss, and a fixed order keeps validation/test passes reproducible.
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=4)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=4)


In [None]:
# Alternative data pipeline (disabled): FashionMNIST with an 80/20 train/validation
# split, kept for reference. It is wrapped in a bare string literal so it never runs.
# NOTE(review): before re-enabling, check it against the model cell. FashionMNIST has
# 10 classes, but the model is built with num_classes=7. Its images are also
# single-channel; confirm the model's expected input channels.
"""
# MNISTdata
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((512, 512))
])

total_data = FashionMNIST('./data', train=True, transform=transform, download=True)
train_data, val_data = random_split(total_data, [int(len(total_data)*0.8), int(len(total_data)*0.2)])
test_data = FashionMNIST('./data', train=False, transform=transform, download=True)

train_loader = DataLoader(train_data, batch_size=64, shuffle=True, num_workers=4)
val_loader = DataLoader(val_data, batch_size=64, shuffle=True, num_workers=4)
test_loader = DataLoader(test_data, batch_size=64, shuffle=True, num_workers=4)
"""

### 모델 및 기타 학습용 객체 정의

In [None]:
# Pretrained ResNet-18 from timm with a 7-class classification head, moved to the device.
model = timm.create_model(
    'resnet18',
    num_classes=7,
    pretrained=True,
)
model = model.to(device)

# Adam over all model parameters.
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

# Negative log-likelihood expects log-probabilities: the training and validation
# loops apply log_softmax to the model output before calling this criterion.
criterion = nn.NLLLoss()

In [None]:
# Freeze ResNet-18's feature-extraction layers and train only the classifier head (`fc`).
# The flag is `requires_grad` (with the trailing "s"). Assigning a misspelled attribute
# on a Parameter silently creates an unused attribute and freezes nothing.
model.requires_grad_(False)
model.fc.requires_grad_(True)
# NOTE(review): BatchNorm layers in the frozen backbone still update their running
# statistics while the model is in train() mode. Switch them to eval() if they must
# stay fixed.

### 학습 함수 정의

In [None]:
def train(net=None, loader=None, opt=None, loss_fn=None, dev=None):
    """Run one training epoch and return the average loss per training sample.

    Every argument is optional. When omitted, the notebook-level object for that
    role is used (``model``, ``train_loader``, ``optimizer``, ``criterion``,
    ``device``), so a bare ``train()`` call works exactly as before.

    Args:
        net: module to train; its raw outputs are passed through log_softmax.
        loader: iterable of ``(data, label)`` batches.
        opt: optimizer stepping ``net``'s parameters.
        loss_fn: criterion on log-probabilities. Assumes ``reduction='mean'``
            (the ``nn.NLLLoss()`` default), since batch losses are re-weighted
            by batch size below.
        dev: device the batches are moved to.

    Returns:
        float: mean per-sample loss over the epoch.
    """
    net = model if net is None else net
    loader = train_loader if loader is None else loader
    opt = optimizer if opt is None else opt
    loss_fn = criterion if loss_fn is None else loss_fn
    dev = device if dev is None else dev

    net.train()
    total_loss = 0.0
    n_samples = 0

    print('training process')
    for data, label in loader:
        data = data.to(dev)
        label = label.to(dev)

        logit = net(data)
        # Explicit dim: for (N, C) logits the classes are on dim 1. The implicit-dim
        # form is deprecated.
        output = F.log_softmax(logit, dim=1)
        loss = loss_fn(output, label)

        opt.zero_grad()
        loss.backward()
        opt.step()

        # loss is the batch *mean*; weight it by batch size so the epoch figure is
        # a true per-sample average (the last batch may be smaller).
        batch_n = data.size(0)
        total_loss += loss.item() * batch_n
        n_samples += batch_n

    return total_loss / n_samples


def validation(net=None, loader=None, loss_fn=None, dev=None):
    """Evaluate the model on the validation set and return the average loss per sample.

    This function never updates weights: no backward pass, no optimizer step,
    and autograd is disabled. Every argument is optional. When omitted, the
    notebook-level object for that role is used (``model``, ``val_loader``,
    ``criterion``, ``device``), so a bare ``validation()`` call keeps working.

    Args:
        net: module to evaluate; its raw outputs are passed through log_softmax.
        loader: iterable of ``(data, label)`` batches.
        loss_fn: criterion on log-probabilities. Assumes ``reduction='mean'``
            (the ``nn.NLLLoss()`` default).
        dev: device the batches are moved to.

    Returns:
        float: mean per-sample validation loss.
    """
    net = model if net is None else net
    loader = val_loader if loader is None else loader
    loss_fn = criterion if loss_fn is None else loss_fn
    dev = device if dev is None else dev

    net.eval()
    total_loss = 0.0
    n_samples = 0

    print('validation process')
    # Evaluation only. Stepping the optimizer here would train on validation data
    # and invalidate the validation loss as a generalization estimate.
    with torch.no_grad():
        for data, label in loader:
            data = data.to(dev)
            label = label.to(dev)

            logit = net(data)
            output = F.log_softmax(logit, dim=1)
            loss = loss_fn(output, label)

            # Re-weight the batch-mean loss by batch size for a per-sample average.
            batch_n = data.size(0)
            total_loss += loss.item() * batch_n
            n_samples += batch_n

    return total_loss / n_samples

### 학습 진행

In [None]:
# Train for `epochs` epochs, evaluating on the validation set after each one.
for epoch in range(epochs):
    print('==================={}=================='.format(epoch))
    train_loss = train()
    val_loss = validation()

    # Log both losses in one call. Each run.log() advances wandb's step counter,
    # so separate calls would put train and val metrics on different steps.
    run.log({'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss})

    print('epoch : {} train_loss : {:.4f} val_loss : {:.4f}'.format(epoch, train_loss, val_loss))

In [None]:
# Persist the whole model object (architecture and weights) to model/first.pt.
# torch.save does not create missing directories, so create the parent first.
# NOTE(review): loading a fully pickled module requires
# torch.load(..., weights_only=False) on recent PyTorch. Saving model.state_dict()
# is the more portable alternative.
model_path = Path('model') / 'first.pt'
model_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model, model_path)

In [None]:
# Mark the wandb run as finished so its logged data is flushed and the run is closed.
run.finish()