In [None]:
from __future__ import print_function

import glob, os, random, torch, timm, shutil, pickle, time
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pytorch_lightning as pl

from PIL import Image, ImageEnhance, ImageOps
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
from pathlib import Path
from pprint import pprint
from tempfile import TemporaryDirectory
from torch.cuda.amp import autocast, GradScaler

# Enumerate GPUs in PCI bus order and expose only GPU 0 to this process.
# NOTE(review): these variables only take effect if CUDA has not been initialised
# yet in this kernel; they are set after `import torch`/`timm`, which is normally
# fine because CUDA initialises lazily — confirm if device selection misbehaves.
os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device} device")

# Global seed, passed to seed_everything() further down.
seed = 0

%matplotlib inline



In [None]:
# Environment sanity check: PyTorch version, CUDA availability, visible GPU count.
for env_fact in (torch.__version__, torch.cuda.is_available(), torch.cuda.device_count()):
    print(env_fact)

In [None]:
def seed_everything(seed):
    """Seed every RNG used by the notebook and make cuDNN deterministic.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and PyTorch
    (CPU and CUDA), exports ``PYTHONHASHSEED``, and configures cuDNN for
    reproducible (non-autotuned, deterministic) kernels.

    Args:
        seed: integer seed applied to all generators.
    """
    random.seed(seed)
    # NOTE: this does not change hashing in the already-running interpreter;
    # it only reaches subprocesses started afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning may choose different kernels from run to run; disable it
    # explicitly so the `deterministic` flag above is not undermined.
    torch.backends.cudnn.benchmark = False

# Seed all RNGs once, before any data loading or model construction below.
seed_everything(seed)

In [None]:
###############################################################################
# Class-imbalance correction (per-class loss weights)
###############################################################################

In [None]:
orig_dir = '/tf/notebooks/EffNet/Corrected_YOLO_data_123_predict_base456_for_EffNet_train_val/JCO123_corrected_converted_to_YOLO_thacher'
train_dir = os.path.join(orig_dir, 'train')


def list_class_dirs(root):
    """Return the sorted paths of the immediate sub-directories of ``root``.

    Only directories are treated as classes. A stray file next to the class
    folders used to be listed too; it would then get a count of 0 and an
    infinite class weight further down.
    """
    return sorted(path for path in glob.glob(os.path.join(root, '*')) if os.path.isdir(path))


data_1 = list_class_dirs(train_dir)
data_1

In [None]:
# Prefix of every checkpoint file written by train_model(), e.g. "<save_name>_1st_model.f_val_ep00_...pt".
save_name = '20240415_JCO123_EV2T'

In [None]:
# Class name (folder basename) and number of .png training images per class folder.
# NOTE(review): only *.png files are counted; if a class folder holds other image
# types that ImageFolder also loads, the class weights will be off — confirm.
annotated_group = [os.path.basename(class_dir) for class_dir in data_1]
annotated_len = [len(glob.glob(class_dir + '/*.png')) for class_dir in data_1]

temp_dict = dict(zip(annotated_group, annotated_len))
temp_dict

In [None]:
# Total number of .png training images across all classes.
sum(annotated_len)

In [None]:
# Inverse-frequency class weights: total / (n_classes * count_c). A class with the
# average number of images gets weight 1.0; rarer classes get proportionally more.
# An empty class would produce an infinite weight (division by zero) and an inf/NaN
# loss much later, so fail here with a readable message instead.
if not annotated_len or 0 in annotated_len:
    raise ValueError(
        f'Every class needs at least one .png training image, got counts {dict(zip(annotated_group, annotated_len))}')
WEIGHT = 1/np.array(annotated_len)*sum(annotated_len)/len(annotated_len)
WEIGHT

In [None]:
# The same weights keyed by class index (0 .. number of classes - 1).
# NOTE(review): CLASS_WEIGHT is not used anywhere else in this notebook.
CLASS_WEIGHT = {class_idx: class_weight for class_idx, class_weight in enumerate(WEIGHT)}
CLASS_WEIGHT

In [None]:
# Move the class weights to the training device as float32 for CrossEntropyLoss.
# `.cuda()` crashed on CPU-only machines even though `device` falls back to CPU;
# `as_tensor` also makes re-running this cell harmless once WEIGHT is a tensor.
WEIGHT = torch.as_tensor(WEIGHT, dtype=torch.float32, device=device)
WEIGHT

In [None]:
NUM_CLASSES = len(annotated_group)  # one output logit per class folder found under train/
IMG_SIZE = 224    # square resize target used by both the train and val transforms
BATCH_SIZE = 256  # batch size of both DataLoaders
WORKERS = 32      # DataLoader worker processes
EPOCHS = 50       # max epochs per fine-tuning stage (early stopping may end a stage sooner)
LR = 3e-5         # Adam learning rate; each stage builds a new optimizer starting from this value

In [None]:
#Base model 	resolution
#EfficientNetB0 	224
#EfficientNetB1 	240
#EfficientNetB2 	260
#EfficientNetB3 	300
#EfficientNetB4 	380
#EfficientNetB5 	456
#EfficientNetB6 	528
#EfficientNetB7 	600

#Efficientnet V2 S 	384
#Efficientnet V2 M 	480 
#Efficientnet V2 L 	480 
#Efficientnetv2 B0 	224 
#Efficientnetv2 B1 	240 
#Efficientnetv2 B2 	260 
#Efficientnetv2 B3 	300 

In [None]:
# Temporary checkpoint filename.
# NOTE(review): `temp_na` is not referenced anywhere else in this notebook;
# train_model() writes its temporary checkpoint inside its own TemporaryDirectory.
temp_na = 'best_model_.pt'

In [None]:
# Early stopping: train_model() ends a stage after this many consecutive epochs
# without a new lowest validation loss.
patience = 5

In [None]:
#####

In [None]:
# Train/val split folders as pathlib paths.
# NOTE(review): val_dataset_dir is not used elsewhere in this notebook.
train_dataset_dir = Path(orig_dir) / 'train'
val_dataset_dir = Path(orig_dir) / 'val'

In [None]:
# Preview up to 9 distinct raw training images, titled with their class folder.
# The previous np.random.randint(1, len(files)) could never pick files[0] and could
# show the same image twice; the files were also never closed.
files = glob.glob(str(train_dataset_dir) + '/*/*.png')
random_idx = np.random.choice(len(files), size=min(9, len(files)), replace=False)
fig, axes = plt.subplots(3, 3, figsize=(8, 6))

for idx, ax in zip(random_idx, axes.ravel()):
    with Image.open(files[idx]) as img:
        ax.imshow(img)
    ax.set_title(os.path.basename(os.path.dirname(files[idx])), fontsize=8)

In [None]:
# Per-split preprocessing. Both pipelines resize to IMG_SIZE x IMG_SIZE and end with
# ToTensor; training additionally applies RandAugment (3 ops, magnitude 7).
# NOTE(review): no transforms.Normalize is applied although the backbone is
# pretrained; timm's default_cfg (shown further down) lists the mean/std used in
# pretraining — confirm that omitting it is intentional.
_resize_to_model_input = transforms.Resize((IMG_SIZE, IMG_SIZE))
data_transforms = dict(
    train=transforms.Compose([
        _resize_to_model_input,
        transforms.RandAugment(num_ops=3, magnitude=7),
        transforms.ToTensor(),
    ]),
    val=transforms.Compose([
        _resize_to_model_input,
        transforms.ToTensor(),
    ]),
)

In [None]:
data_dir = orig_dir
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                  for x in ['train', 'val']}
# Shuffle only the training split: evaluation does not need it, and a fixed order
# keeps validation passes comparable across runs.
dataloaders = {x: DataLoader(image_datasets[x], batch_size=BATCH_SIZE, shuffle=(x == 'train'),
                             num_workers=WORKERS, pin_memory=torch.cuda.is_available())
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

# ImageFolder numbers classes separately for each split, and WEIGHT was built from
# the folder listing in `annotated_group`. Labels and class weights are only
# meaningful if all three orderings agree, so check that explicitly.
assert image_datasets['val'].classes == class_names, (
    f"val classes {image_datasets['val'].classes} differ from train classes {class_names}")
assert class_names == annotated_group, (
    f'ImageFolder class order {class_names} differs from class-weight order {annotated_group}')

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Display the train/val ImageFolder datasets.
image_datasets

In [None]:
# Images per split; train_model() divides the summed loss/corrects by these.
dataset_sizes

In [None]:
# Device that train_model() moves each batch to.
device

In [None]:
# Pull one training batch to sanity-check tensor shape and value range.
# (The old loop iterated the loader up to its second batch only to keep `j`.)
# `j` stays a [images, labels] pair because the next cell reads j[0].
j = next(iter(dataloaders['train']))
print(j[0].size())

temp = j[0].detach().cpu().numpy()
print('max: ', temp.max())
print('min: ', temp.min())
print('mean: ', temp.mean())
print('std: ', temp.std())

In [None]:
# Show up to 9 distinct augmented images from the sanity-check batch, with class names.
# randint(1, BATCH_SIZE) never picked index 0, could repeat images, and raised
# IndexError whenever the batch was smaller than BATCH_SIZE.
batch_images, batch_labels = j[0], j[1]
random_idx = np.random.choice(len(batch_images), size=min(9, len(batch_images)), replace=False)
fig, axes = plt.subplots(3, 3, figsize=(8, 6))

for idx, ax in zip(random_idx, axes.ravel()):
    img = transforms.ToPILImage()(batch_images[idx])
    ax.imshow(img)
    ax.set_title(class_names[int(batch_labels[idx])], fontsize=8)

In [None]:
#####

In [None]:
model_names = timm.list_models(pretrained=True)
# The full catalogue is very long; print only the EfficientNet family this
# notebook chooses its backbone from (model_names still holds everything).
pprint([name for name in model_names if 'efficientnet' in name])

In [None]:
class MyImgModel(nn.Module):
    """timm backbone followed by a small MLP classification head.

    The backbone is created with ``num_classes=0`` and its output (of width
    ``backbone.num_features``) is fed to the head:
    Dropout -> Linear(in_features, hidden_dim) -> ReLU -> Dropout -> Linear(hidden_dim, out_dim).

    Args:
        model_name: model identifier passed to ``timm.create_model``.
        pretrained: whether timm loads pretrained backbone weights.
        hidden_dim: width of the head's hidden layer.
        out_dim: number of output logits (classes).
        dropout: dropout probability before each Linear layer of the head
            (default 0.2, the value that used to be hard-coded twice).
    """

    def __init__(self,
                 model_name: str, pretrained: bool,
                 hidden_dim: int, out_dim: int,
                 dropout: float = 0.2):
        super().__init__()
        self.backbone = timm.create_model(model_name,
                                          pretrained=pretrained,
                                          num_classes=0)
        self.in_features = self.backbone.num_features
        self.head = nn.Sequential(nn.Dropout(p=dropout),
                                  nn.Linear(self.in_features, hidden_dim),
                                  nn.ReLU(),
                                  nn.Dropout(p=dropout),
                                  nn.Linear(hidden_dim, out_dim))

    def forward(self, x):
        """Return the head's logits for the backbone features of ``x``."""
        h = self.backbone(x)
        y = self.head(h)
        return y

In [None]:
# timm model identifier of the backbone (also used for the metadata check below).
model_selected = 'efficientnetv2_rw_t.ra2_in1k'

In [None]:
# Pretrained backbone + MLP head sized for our classes. Put it on `device`, the
# device every training batch is moved to; the previous hard-coded "cuda:0"
# failed on a CPU-only machine.
model = MyImgModel(model_name=model_selected, pretrained=True, hidden_dim=256, out_dim=NUM_CLASSES)
model.to(device)

In [None]:
# Inspect the freshly built model's weights (displays every tensor; large output).
model.state_dict()

In [None]:
## Optionally load previously saved weights (replace './XXX' with a checkpoint path)
#model.load_state_dict(torch.load('./XXX'))

In [None]:
# Inspect the weights again, e.g. after the optional load above (large output).
model.state_dict()

In [None]:
# Inspect the timm configuration (input size, normalization mean/std, ...) of the
# backbone we already built, instead of instantiating and loading a second
# pretrained network just to read its metadata. The next cells read
# `num_features` and `feature_info` from this same object.
model_check = model.backbone
model_check.default_cfg

In [None]:
# Width of the backbone's feature output; should equal MyImgModel.in_features.
model_check.num_features

In [None]:
# timm feature metadata for the backbone (presumably per-stage channels/reduction).
model_check.feature_info

In [None]:
# List every parameter tensor with its index and trainable flag. The `layer`
# thresholds in the fine-tuning cells compare against this index.
for param_idx, (param_name, param) in enumerate(model.named_parameters()):
    print(param_idx, ': ', param.requires_grad, ': ', param_name)

In [None]:
# Weights right before the first fine-tuning stage (large output).
model.state_dict()

In [None]:
######################

In [None]:
# Class-weighted cross-entropy (WEIGHT offsets the class imbalance computed above),
# Adam at LR, and an LR scheduler that multiplies the learning rate by 0.3 once the
# validation loss has failed to improve for more than 3 epochs. train_model()
# steps the scheduler with the validation loss.
criterion = nn.CrossEntropyLoss(weight=WEIGHT).to(torch.float)
optimizer = optim.Adam(params=model.parameters(), lr=LR)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.3, patience=3, verbose=True)

In [None]:
# Stage 1: freeze everything, then unfreeze parameter tensors from index `layer`
# onward (model.named_parameters() order) whose name does not contain 'bn'.
# NOTE(review): 'bn' presumably matches the BatchNorm affine parameters — confirm
# against the printed names. Frozen BatchNorm layers still update their running
# statistics while the model is in train() mode.
layer = 436

model.requires_grad_(False)
for param_idx, (param_name, param) in enumerate(model.named_parameters()):
    if param_idx >= layer and 'bn' not in param_name:
        param.requires_grad = True
    print(param_idx, ': ', param.requires_grad, ': ', param_name)

In [None]:
def train_model(model, criterion, optimizer, scheduler, num_epochs, stage="1st"):
    """Train ``model`` with mixed precision, plateau LR scheduling and early stopping.

    Each epoch makes one pass over ``dataloaders['train']`` with weight updates,
    then one pass over ``dataloaders['val']`` without gradients.
    ``dataloaders``, ``dataset_sizes``, ``device``, ``save_name`` and ``patience``
    are read from the notebook's globals.

    After each validation pass, the scheduler is stepped with the validation loss.
    Whenever the validation loss reaches a new minimum and/or the validation
    accuracy reaches a new maximum, the weights are saved once to a temporary
    "best" checkpoint and once to
    ``{save_name}_{stage}_model.f_val_ep{EE}_loss{L}_acc{A}.pt`` in the current
    working directory. Training stops after ``patience`` consecutive epochs
    without a new lowest validation loss.

    Args:
        model: network to train; updated in place.
        criterion: loss, called as ``criterion(logits, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: stepped once per epoch as ``scheduler.step(val_loss)``.
        num_epochs: maximum number of epochs.
        stage: tag placed in checkpoint filenames so that fine-tuning stages
            do not overwrite each other (default ``"1st"``).

    Returns:
        ``model``, reloaded with the weights of the most recent checkpoint that
        improved either validation metric (left as-is if none was saved).
    """
    since = time.time()

    # Loss scaling for mixed-precision training.
    scaler = GradScaler()

    # Early-stopping and checkpoint bookkeeping.
    counter = 0
    best_val_loss = float('inf')
    best_acc = 0.0
    early_stop = False

    filename_prefix = f"{save_name}_{stage}_model.f_"

    with TemporaryDirectory() as tempdir:
        best_model_params_path = os.path.join(tempdir, "best_model.pt")

        for epoch in range(num_epochs):
            print(f'Epoch {epoch}/{num_epochs - 1}')
            print('-' * 10)

            for phase in ['train', 'val']:
                is_train = phase == 'train'
                model.train(is_train)  # train(False) is the same as eval()

                running_loss = 0.0
                running_corrects = 0

                # Record the autograd graph only while training. The validation
                # pass used to build (and throw away) one for every batch.
                with torch.set_grad_enabled(is_train):
                    for inputs, labels in dataloaders[phase]:
                        inputs = inputs.to(device)
                        labels = labels.to(device)

                        # Mixed-precision forward pass.
                        with autocast():
                            outputs = model(inputs)
                            loss = criterion(outputs, labels)

                        if is_train:
                            optimizer.zero_grad()
                            scaler.scale(loss).backward()
                            scaler.step(optimizer)
                            scaler.update()

                        preds = outputs.argmax(dim=1)
                        running_loss += loss.item() * inputs.size(0)
                        running_corrects += (preds == labels).sum().item()

                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_corrects / dataset_sizes[phase]

                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

                if phase == 'val':
                    scheduler.step(epoch_loss)

                    loss_improved = epoch_loss < best_val_loss
                    acc_improved = epoch_acc > best_acc

                    if loss_improved:
                        print(f'Validation loss decreased ({best_val_loss:.4f} --> {epoch_loss:.4f}).  Saving model ...')
                        best_val_loss = epoch_loss
                        counter = 0
                    else:
                        counter += 1
                        print(f'EarlyStopping counter: {counter} out of {patience}')
                        if counter >= patience:
                            early_stop = True

                    if acc_improved:
                        best_acc = epoch_acc

                    # Save once per epoch even when both metrics improved (the same
                    # two files used to be written twice in that case).
                    if loss_improved or acc_improved:
                        torch.save(model.state_dict(), best_model_params_path)
                        filepath = f'{filename_prefix}{phase}_ep{epoch:02d}_loss{epoch_loss:.4f}_acc{epoch_acc:.4f}.pt'
                        torch.save(model.state_dict(), filepath)

                    if acc_improved:
                        print('lr: ', optimizer.param_groups[0]['lr'])

            if early_stop:
                print("Early stopping")
                break

            print()

        time_elapsed = time.time() - since
        print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
        print(f'Best val Acc: {best_acc:.4f}')

        # Nothing is saved if no epoch ran, or if the loss was never finite and the
        # accuracy never rose above 0. Keep the current weights in that case
        # instead of crashing on a missing file.
        if os.path.exists(best_model_params_path):
            model.load_state_dict(torch.load(best_model_params_path, map_location=device))

    return model


In [None]:
# Stage 1 training. train_model() updates `model` in place and returns the same
# object; assigning it avoids dumping the whole model repr as cell output.
model = train_model(model=model, criterion=criterion, optimizer=optimizer, scheduler=scheduler, num_epochs=EPOCHS)

In [None]:
# Weights after stage 1 (large output).
model.state_dict()

In [None]:
######

In [None]:
# Stage 2: freeze everything, then unfreeze parameter tensors from index `layer`
# onward whose name does not contain 'bn' (a larger trainable tail than stage 1).
layer = 251

model.requires_grad_(False)
for param_idx, (param_name, param) in enumerate(model.named_parameters()):
    if param_idx >= layer and 'bn' not in param_name:
        param.requires_grad = True
    print(param_idx, ': ', param.requires_grad, ': ', param_name)

# New loss/optimizer/scheduler for this stage, so Adam's state and the learning
# rate start over from LR.
criterion = nn.CrossEntropyLoss(weight=WEIGHT).to(torch.float)
optimizer = optim.Adam(params=model.parameters(), lr=LR)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.3, patience=3, verbose=True)

In [None]:
def train_model(model, criterion, optimizer, scheduler, num_epochs, stage="2nd"):
    """Train ``model`` with mixed precision, plateau LR scheduling and early stopping.

    Each epoch makes one pass over ``dataloaders['train']`` with weight updates,
    then one pass over ``dataloaders['val']`` without gradients.
    ``dataloaders``, ``dataset_sizes``, ``device``, ``save_name`` and ``patience``
    are read from the notebook's globals.

    After each validation pass, the scheduler is stepped with the validation loss.
    Whenever the validation loss reaches a new minimum and/or the validation
    accuracy reaches a new maximum, the weights are saved once to a temporary
    "best" checkpoint and once to
    ``{save_name}_{stage}_model.f_val_ep{EE}_loss{L}_acc{A}.pt`` in the current
    working directory. Training stops after ``patience`` consecutive epochs
    without a new lowest validation loss.

    Args:
        model: network to train; updated in place.
        criterion: loss, called as ``criterion(logits, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: stepped once per epoch as ``scheduler.step(val_loss)``.
        num_epochs: maximum number of epochs.
        stage: tag placed in checkpoint filenames so that fine-tuning stages
            do not overwrite each other (default ``"2nd"``).

    Returns:
        ``model``, reloaded with the weights of the most recent checkpoint that
        improved either validation metric (left as-is if none was saved).
    """
    since = time.time()

    # Loss scaling for mixed-precision training.
    scaler = GradScaler()

    # Early-stopping and checkpoint bookkeeping.
    counter = 0
    best_val_loss = float('inf')
    best_acc = 0.0
    early_stop = False

    filename_prefix = f"{save_name}_{stage}_model.f_"

    with TemporaryDirectory() as tempdir:
        best_model_params_path = os.path.join(tempdir, "best_model.pt")

        for epoch in range(num_epochs):
            print(f'Epoch {epoch}/{num_epochs - 1}')
            print('-' * 10)

            for phase in ['train', 'val']:
                is_train = phase == 'train'
                model.train(is_train)  # train(False) is the same as eval()

                running_loss = 0.0
                running_corrects = 0

                # Record the autograd graph only while training. The validation
                # pass used to build (and throw away) one for every batch.
                with torch.set_grad_enabled(is_train):
                    for inputs, labels in dataloaders[phase]:
                        inputs = inputs.to(device)
                        labels = labels.to(device)

                        # Mixed-precision forward pass.
                        with autocast():
                            outputs = model(inputs)
                            loss = criterion(outputs, labels)

                        if is_train:
                            optimizer.zero_grad()
                            scaler.scale(loss).backward()
                            scaler.step(optimizer)
                            scaler.update()

                        preds = outputs.argmax(dim=1)
                        running_loss += loss.item() * inputs.size(0)
                        running_corrects += (preds == labels).sum().item()

                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_corrects / dataset_sizes[phase]

                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

                if phase == 'val':
                    scheduler.step(epoch_loss)

                    loss_improved = epoch_loss < best_val_loss
                    acc_improved = epoch_acc > best_acc

                    if loss_improved:
                        print(f'Validation loss decreased ({best_val_loss:.4f} --> {epoch_loss:.4f}).  Saving model ...')
                        best_val_loss = epoch_loss
                        counter = 0
                    else:
                        counter += 1
                        print(f'EarlyStopping counter: {counter} out of {patience}')
                        if counter >= patience:
                            early_stop = True

                    if acc_improved:
                        best_acc = epoch_acc

                    # Save once per epoch even when both metrics improved (the same
                    # two files used to be written twice in that case).
                    if loss_improved or acc_improved:
                        torch.save(model.state_dict(), best_model_params_path)
                        filepath = f'{filename_prefix}{phase}_ep{epoch:02d}_loss{epoch_loss:.4f}_acc{epoch_acc:.4f}.pt'
                        torch.save(model.state_dict(), filepath)

                    if acc_improved:
                        print('lr: ', optimizer.param_groups[0]['lr'])

            if early_stop:
                print("Early stopping")
                break

            print()

        time_elapsed = time.time() - since
        print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
        print(f'Best val Acc: {best_acc:.4f}')

        # Nothing is saved if no epoch ran, or if the loss was never finite and the
        # accuracy never rose above 0. Keep the current weights in that case
        # instead of crashing on a missing file.
        if os.path.exists(best_model_params_path):
            model.load_state_dict(torch.load(best_model_params_path, map_location=device))

    return model


In [None]:
# Stage 2 training. train_model() updates `model` in place and returns the same
# object; assigning it avoids dumping the whole model repr as cell output.
model = train_model(model=model, criterion=criterion, optimizer=optimizer, scheduler=scheduler, num_epochs=EPOCHS)

In [None]:
# Weights after stage 2 (large output).
model.state_dict()

In [None]:
#####

In [None]:
# Stage 3: freeze everything, then unfreeze parameter tensors from index `layer`
# onward whose name does not contain 'bn'.
layer = 134

model.requires_grad_(False)
for param_idx, (param_name, param) in enumerate(model.named_parameters()):
    if param_idx >= layer and 'bn' not in param_name:
        param.requires_grad = True
    print(param_idx, ': ', param.requires_grad, ': ', param_name)

# New loss/optimizer/scheduler for this stage, so Adam's state and the learning
# rate start over from LR.
criterion = nn.CrossEntropyLoss(weight=WEIGHT).to(torch.float)
optimizer = optim.Adam(params=model.parameters(), lr=LR)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.3, patience=3, verbose=True)

In [None]:
def train_model(model, criterion, optimizer, scheduler, num_epochs, stage="3rd"):
    """Train ``model`` with mixed precision, plateau LR scheduling and early stopping.

    Each epoch makes one pass over ``dataloaders['train']`` with weight updates,
    then one pass over ``dataloaders['val']`` without gradients.
    ``dataloaders``, ``dataset_sizes``, ``device``, ``save_name`` and ``patience``
    are read from the notebook's globals.

    After each validation pass, the scheduler is stepped with the validation loss.
    Whenever the validation loss reaches a new minimum and/or the validation
    accuracy reaches a new maximum, the weights are saved once to a temporary
    "best" checkpoint and once to
    ``{save_name}_{stage}_model.f_val_ep{EE}_loss{L}_acc{A}.pt`` in the current
    working directory. Training stops after ``patience`` consecutive epochs
    without a new lowest validation loss.

    Args:
        model: network to train; updated in place.
        criterion: loss, called as ``criterion(logits, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: stepped once per epoch as ``scheduler.step(val_loss)``.
        num_epochs: maximum number of epochs.
        stage: tag placed in checkpoint filenames so that fine-tuning stages
            do not overwrite each other (default ``"3rd"``).

    Returns:
        ``model``, reloaded with the weights of the most recent checkpoint that
        improved either validation metric (left as-is if none was saved).
    """
    since = time.time()

    # Loss scaling for mixed-precision training.
    scaler = GradScaler()

    # Early-stopping and checkpoint bookkeeping.
    counter = 0
    best_val_loss = float('inf')
    best_acc = 0.0
    early_stop = False

    filename_prefix = f"{save_name}_{stage}_model.f_"

    with TemporaryDirectory() as tempdir:
        best_model_params_path = os.path.join(tempdir, "best_model.pt")

        for epoch in range(num_epochs):
            print(f'Epoch {epoch}/{num_epochs - 1}')
            print('-' * 10)

            for phase in ['train', 'val']:
                is_train = phase == 'train'
                model.train(is_train)  # train(False) is the same as eval()

                running_loss = 0.0
                running_corrects = 0

                # Record the autograd graph only while training. The validation
                # pass used to build (and throw away) one for every batch.
                with torch.set_grad_enabled(is_train):
                    for inputs, labels in dataloaders[phase]:
                        inputs = inputs.to(device)
                        labels = labels.to(device)

                        # Mixed-precision forward pass.
                        with autocast():
                            outputs = model(inputs)
                            loss = criterion(outputs, labels)

                        if is_train:
                            optimizer.zero_grad()
                            scaler.scale(loss).backward()
                            scaler.step(optimizer)
                            scaler.update()

                        preds = outputs.argmax(dim=1)
                        running_loss += loss.item() * inputs.size(0)
                        running_corrects += (preds == labels).sum().item()

                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_corrects / dataset_sizes[phase]

                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

                if phase == 'val':
                    scheduler.step(epoch_loss)

                    loss_improved = epoch_loss < best_val_loss
                    acc_improved = epoch_acc > best_acc

                    if loss_improved:
                        print(f'Validation loss decreased ({best_val_loss:.4f} --> {epoch_loss:.4f}).  Saving model ...')
                        best_val_loss = epoch_loss
                        counter = 0
                    else:
                        counter += 1
                        print(f'EarlyStopping counter: {counter} out of {patience}')
                        if counter >= patience:
                            early_stop = True

                    if acc_improved:
                        best_acc = epoch_acc

                    # Save once per epoch even when both metrics improved (the same
                    # two files used to be written twice in that case).
                    if loss_improved or acc_improved:
                        torch.save(model.state_dict(), best_model_params_path)
                        filepath = f'{filename_prefix}{phase}_ep{epoch:02d}_loss{epoch_loss:.4f}_acc{epoch_acc:.4f}.pt'
                        torch.save(model.state_dict(), filepath)

                    if acc_improved:
                        print('lr: ', optimizer.param_groups[0]['lr'])

            if early_stop:
                print("Early stopping")
                break

            print()

        time_elapsed = time.time() - since
        print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
        print(f'Best val Acc: {best_acc:.4f}')

        # Nothing is saved if no epoch ran, or if the loss was never finite and the
        # accuracy never rose above 0. Keep the current weights in that case
        # instead of crashing on a missing file.
        if os.path.exists(best_model_params_path):
            model.load_state_dict(torch.load(best_model_params_path, map_location=device))

    return model


In [None]:
# Stage 3 training. train_model() updates `model` in place and returns the same
# object; assigning it avoids dumping the whole model repr as cell output.
model = train_model(model=model, criterion=criterion, optimizer=optimizer, scheduler=scheduler, num_epochs=EPOCHS)

In [None]:
#####

In [None]:
# Stage 4: freeze everything, then unfreeze parameter tensors from index `layer`
# onward whose name does not contain 'bn'.
layer = 56

model.requires_grad_(False)
for param_idx, (param_name, param) in enumerate(model.named_parameters()):
    if param_idx >= layer and 'bn' not in param_name:
        param.requires_grad = True
    print(param_idx, ': ', param.requires_grad, ': ', param_name)

# New loss/optimizer/scheduler for this stage, so Adam's state and the learning
# rate start over from LR.
criterion = nn.CrossEntropyLoss(weight=WEIGHT).to(torch.float)
optimizer = optim.Adam(params=model.parameters(), lr=LR)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.3, patience=3, verbose=True)

In [None]:
def train_model(model, criterion, optimizer, scheduler, num_epochs, stage="4th"):
    """Train ``model`` with mixed precision, plateau LR scheduling and early stopping.

    Each epoch makes one pass over ``dataloaders['train']`` with weight updates,
    then one pass over ``dataloaders['val']`` without gradients.
    ``dataloaders``, ``dataset_sizes``, ``device``, ``save_name`` and ``patience``
    are read from the notebook's globals.

    After each validation pass, the scheduler is stepped with the validation loss.
    Whenever the validation loss reaches a new minimum and/or the validation
    accuracy reaches a new maximum, the weights are saved once to a temporary
    "best" checkpoint and once to
    ``{save_name}_{stage}_model.f_val_ep{EE}_loss{L}_acc{A}.pt`` in the current
    working directory. Training stops after ``patience`` consecutive epochs
    without a new lowest validation loss.

    Args:
        model: network to train; updated in place.
        criterion: loss, called as ``criterion(logits, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: stepped once per epoch as ``scheduler.step(val_loss)``.
        num_epochs: maximum number of epochs.
        stage: tag placed in checkpoint filenames so that fine-tuning stages
            do not overwrite each other (default ``"4th"``).

    Returns:
        ``model``, reloaded with the weights of the most recent checkpoint that
        improved either validation metric (left as-is if none was saved).
    """
    since = time.time()

    # Loss scaling for mixed-precision training.
    scaler = GradScaler()

    # Early-stopping and checkpoint bookkeeping.
    counter = 0
    best_val_loss = float('inf')
    best_acc = 0.0
    early_stop = False

    filename_prefix = f"{save_name}_{stage}_model.f_"

    with TemporaryDirectory() as tempdir:
        best_model_params_path = os.path.join(tempdir, "best_model.pt")

        for epoch in range(num_epochs):
            print(f'Epoch {epoch}/{num_epochs - 1}')
            print('-' * 10)

            for phase in ['train', 'val']:
                is_train = phase == 'train'
                model.train(is_train)  # train(False) is the same as eval()

                running_loss = 0.0
                running_corrects = 0

                # Record the autograd graph only while training. The validation
                # pass used to build (and throw away) one for every batch.
                with torch.set_grad_enabled(is_train):
                    for inputs, labels in dataloaders[phase]:
                        inputs = inputs.to(device)
                        labels = labels.to(device)

                        # Mixed-precision forward pass.
                        with autocast():
                            outputs = model(inputs)
                            loss = criterion(outputs, labels)

                        if is_train:
                            optimizer.zero_grad()
                            scaler.scale(loss).backward()
                            scaler.step(optimizer)
                            scaler.update()

                        preds = outputs.argmax(dim=1)
                        running_loss += loss.item() * inputs.size(0)
                        running_corrects += (preds == labels).sum().item()

                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_corrects / dataset_sizes[phase]

                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

                if phase == 'val':
                    scheduler.step(epoch_loss)

                    loss_improved = epoch_loss < best_val_loss
                    acc_improved = epoch_acc > best_acc

                    if loss_improved:
                        print(f'Validation loss decreased ({best_val_loss:.4f} --> {epoch_loss:.4f}).  Saving model ...')
                        best_val_loss = epoch_loss
                        counter = 0
                    else:
                        counter += 1
                        print(f'EarlyStopping counter: {counter} out of {patience}')
                        if counter >= patience:
                            early_stop = True

                    if acc_improved:
                        best_acc = epoch_acc

                    # Save once per epoch even when both metrics improved (the same
                    # two files used to be written twice in that case).
                    if loss_improved or acc_improved:
                        torch.save(model.state_dict(), best_model_params_path)
                        filepath = f'{filename_prefix}{phase}_ep{epoch:02d}_loss{epoch_loss:.4f}_acc{epoch_acc:.4f}.pt'
                        torch.save(model.state_dict(), filepath)

                    if acc_improved:
                        print('lr: ', optimizer.param_groups[0]['lr'])

            if early_stop:
                print("Early stopping")
                break

            print()

        time_elapsed = time.time() - since
        print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
        print(f'Best val Acc: {best_acc:.4f}')

        # Nothing is saved if no epoch ran, or if the loss was never finite and the
        # accuracy never rose above 0. Keep the current weights in that case
        # instead of crashing on a missing file.
        if os.path.exists(best_model_params_path):
            model.load_state_dict(torch.load(best_model_params_path, map_location=device))

    return model


In [None]:
# Stage 4 training. train_model() updates `model` in place and returns the same
# object; assigning it avoids dumping the whole model repr as cell output.
model = train_model(model=model, criterion=criterion, optimizer=optimizer, scheduler=scheduler, num_epochs=EPOCHS)

In [None]:
#####

In [None]:
# Stage 5: freeze everything, then unfreeze parameter tensors from index `layer`
# onward whose name does not contain 'bn'.
layer = 32

model.requires_grad_(False)
for param_idx, (param_name, param) in enumerate(model.named_parameters()):
    if param_idx >= layer and 'bn' not in param_name:
        param.requires_grad = True
    print(param_idx, ': ', param.requires_grad, ': ', param_name)

# New loss/optimizer/scheduler for this stage, so Adam's state and the learning
# rate start over from LR.
criterion = nn.CrossEntropyLoss(weight=WEIGHT).to(torch.float)
optimizer = optim.Adam(params=model.parameters(), lr=LR)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.3, patience=3, verbose=True)

In [None]:
def train_model(model, criterion, optimizer, scheduler, num_epochs, stage="5th"):
    """Train ``model`` with mixed precision, plateau LR scheduling and early stopping.

    Each epoch makes one pass over ``dataloaders['train']`` with weight updates,
    then one pass over ``dataloaders['val']`` without gradients.
    ``dataloaders``, ``dataset_sizes``, ``device``, ``save_name`` and ``patience``
    are read from the notebook's globals.

    After each validation pass, the scheduler is stepped with the validation loss.
    Whenever the validation loss reaches a new minimum and/or the validation
    accuracy reaches a new maximum, the weights are saved once to a temporary
    "best" checkpoint and once to
    ``{save_name}_{stage}_model.f_val_ep{EE}_loss{L}_acc{A}.pt`` in the current
    working directory. Training stops after ``patience`` consecutive epochs
    without a new lowest validation loss.

    Args:
        model: network to train; updated in place.
        criterion: loss, called as ``criterion(logits, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: stepped once per epoch as ``scheduler.step(val_loss)``.
        num_epochs: maximum number of epochs.
        stage: tag placed in checkpoint filenames so that fine-tuning stages
            do not overwrite each other (default ``"5th"``).

    Returns:
        ``model``, reloaded with the weights of the most recent checkpoint that
        improved either validation metric (left as-is if none was saved).
    """
    since = time.time()

    # Loss scaling for mixed-precision training.
    scaler = GradScaler()

    # Early-stopping and checkpoint bookkeeping.
    counter = 0
    best_val_loss = float('inf')
    best_acc = 0.0
    early_stop = False

    filename_prefix = f"{save_name}_{stage}_model.f_"

    with TemporaryDirectory() as tempdir:
        best_model_params_path = os.path.join(tempdir, "best_model.pt")

        for epoch in range(num_epochs):
            print(f'Epoch {epoch}/{num_epochs - 1}')
            print('-' * 10)

            for phase in ['train', 'val']:
                is_train = phase == 'train'
                model.train(is_train)  # train(False) is the same as eval()

                running_loss = 0.0
                running_corrects = 0

                # Record the autograd graph only while training. The validation
                # pass used to build (and throw away) one for every batch.
                with torch.set_grad_enabled(is_train):
                    for inputs, labels in dataloaders[phase]:
                        inputs = inputs.to(device)
                        labels = labels.to(device)

                        # Mixed-precision forward pass.
                        with autocast():
                            outputs = model(inputs)
                            loss = criterion(outputs, labels)

                        if is_train:
                            optimizer.zero_grad()
                            scaler.scale(loss).backward()
                            scaler.step(optimizer)
                            scaler.update()

                        preds = outputs.argmax(dim=1)
                        running_loss += loss.item() * inputs.size(0)
                        running_corrects += (preds == labels).sum().item()

                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_corrects / dataset_sizes[phase]

                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

                if phase == 'val':
                    scheduler.step(epoch_loss)

                    loss_improved = epoch_loss < best_val_loss
                    acc_improved = epoch_acc > best_acc

                    if loss_improved:
                        print(f'Validation loss decreased ({best_val_loss:.4f} --> {epoch_loss:.4f}).  Saving model ...')
                        best_val_loss = epoch_loss
                        counter = 0
                    else:
                        counter += 1
                        print(f'EarlyStopping counter: {counter} out of {patience}')
                        if counter >= patience:
                            early_stop = True

                    if acc_improved:
                        best_acc = epoch_acc

                    # Save once per epoch even when both metrics improved (the same
                    # two files used to be written twice in that case).
                    if loss_improved or acc_improved:
                        torch.save(model.state_dict(), best_model_params_path)
                        filepath = f'{filename_prefix}{phase}_ep{epoch:02d}_loss{epoch_loss:.4f}_acc{epoch_acc:.4f}.pt'
                        torch.save(model.state_dict(), filepath)

                    if acc_improved:
                        print('lr: ', optimizer.param_groups[0]['lr'])

            if early_stop:
                print("Early stopping")
                break

            print()

        time_elapsed = time.time() - since
        print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
        print(f'Best val Acc: {best_acc:.4f}')

        # Nothing is saved if no epoch ran, or if the loss was never finite and the
        # accuracy never rose above 0. Keep the current weights in that case
        # instead of crashing on a missing file.
        if os.path.exists(best_model_params_path):
            model.load_state_dict(torch.load(best_model_params_path, map_location=device))

    return model


In [None]:
# Stage 5 training. train_model() updates `model` in place and returns the same
# object; assigning it avoids dumping the whole model repr as cell output.
model = train_model(model=model, criterion=criterion, optimizer=optimizer, scheduler=scheduler, num_epochs=EPOCHS)

In [None]:
#####

In [None]:
# Stage 6: freeze everything, then unfreeze parameter tensors from index `layer`
# onward whose name does not contain 'bn' (almost the whole network).
layer = 8

model.requires_grad_(False)
for param_idx, (param_name, param) in enumerate(model.named_parameters()):
    if param_idx >= layer and 'bn' not in param_name:
        param.requires_grad = True
    print(param_idx, ': ', param.requires_grad, ': ', param_name)

# New loss/optimizer/scheduler for this stage, so Adam's state and the learning
# rate start over from LR.
criterion = nn.CrossEntropyLoss(weight=WEIGHT).to(torch.float)
optimizer = optim.Adam(params=model.parameters(), lr=LR)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.3, patience=3, verbose=True)

In [None]:
def train_model(model, criterion, optimizer, scheduler, num_epochs, stage="6th"):
    """Train ``model`` with mixed precision, plateau LR scheduling and early stopping.

    Each epoch makes one pass over ``dataloaders['train']`` with weight updates,
    then one pass over ``dataloaders['val']`` without gradients.
    ``dataloaders``, ``dataset_sizes``, ``device``, ``save_name`` and ``patience``
    are read from the notebook's globals.

    After each validation pass, the scheduler is stepped with the validation loss.
    Whenever the validation loss reaches a new minimum and/or the validation
    accuracy reaches a new maximum, the weights are saved once to a temporary
    "best" checkpoint and once to
    ``{save_name}_{stage}_model.f_val_ep{EE}_loss{L}_acc{A}.pt`` in the current
    working directory. Training stops after ``patience`` consecutive epochs
    without a new lowest validation loss.

    Args:
        model: network to train; updated in place.
        criterion: loss, called as ``criterion(logits, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: stepped once per epoch as ``scheduler.step(val_loss)``.
        num_epochs: maximum number of epochs.
        stage: tag placed in checkpoint filenames so that fine-tuning stages
            do not overwrite each other (default ``"6th"``).

    Returns:
        ``model``, reloaded with the weights of the most recent checkpoint that
        improved either validation metric (left as-is if none was saved).
    """
    since = time.time()

    # Loss scaling for mixed-precision training.
    scaler = GradScaler()

    # Early-stopping and checkpoint bookkeeping.
    counter = 0
    best_val_loss = float('inf')
    best_acc = 0.0
    early_stop = False

    filename_prefix = f"{save_name}_{stage}_model.f_"

    with TemporaryDirectory() as tempdir:
        best_model_params_path = os.path.join(tempdir, "best_model.pt")

        for epoch in range(num_epochs):
            print(f'Epoch {epoch}/{num_epochs - 1}')
            print('-' * 10)

            for phase in ['train', 'val']:
                is_train = phase == 'train'
                model.train(is_train)  # train(False) is the same as eval()

                running_loss = 0.0
                running_corrects = 0

                # Record the autograd graph only while training. The validation
                # pass used to build (and throw away) one for every batch.
                with torch.set_grad_enabled(is_train):
                    for inputs, labels in dataloaders[phase]:
                        inputs = inputs.to(device)
                        labels = labels.to(device)

                        # Mixed-precision forward pass.
                        with autocast():
                            outputs = model(inputs)
                            loss = criterion(outputs, labels)

                        if is_train:
                            optimizer.zero_grad()
                            scaler.scale(loss).backward()
                            scaler.step(optimizer)
                            scaler.update()

                        preds = outputs.argmax(dim=1)
                        running_loss += loss.item() * inputs.size(0)
                        running_corrects += (preds == labels).sum().item()

                epoch_loss = running_loss / dataset_sizes[phase]
                epoch_acc = running_corrects / dataset_sizes[phase]

                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

                if phase == 'val':
                    scheduler.step(epoch_loss)

                    loss_improved = epoch_loss < best_val_loss
                    acc_improved = epoch_acc > best_acc

                    if loss_improved:
                        print(f'Validation loss decreased ({best_val_loss:.4f} --> {epoch_loss:.4f}).  Saving model ...')
                        best_val_loss = epoch_loss
                        counter = 0
                    else:
                        counter += 1
                        print(f'EarlyStopping counter: {counter} out of {patience}')
                        if counter >= patience:
                            early_stop = True

                    if acc_improved:
                        best_acc = epoch_acc

                    # Save once per epoch even when both metrics improved (the same
                    # two files used to be written twice in that case).
                    if loss_improved or acc_improved:
                        torch.save(model.state_dict(), best_model_params_path)
                        filepath = f'{filename_prefix}{phase}_ep{epoch:02d}_loss{epoch_loss:.4f}_acc{epoch_acc:.4f}.pt'
                        torch.save(model.state_dict(), filepath)

                    if acc_improved:
                        print('lr: ', optimizer.param_groups[0]['lr'])

            if early_stop:
                print("Early stopping")
                break

            print()

        time_elapsed = time.time() - since
        print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
        print(f'Best val Acc: {best_acc:.4f}')

        # Nothing is saved if no epoch ran, or if the loss was never finite and the
        # accuracy never rose above 0. Keep the current weights in that case
        # instead of crashing on a missing file.
        if os.path.exists(best_model_params_path):
            model.load_state_dict(torch.load(best_model_params_path, map_location=device))

    return model


In [None]:
# Stage 6 training. train_model() updates `model` in place and returns the same
# object; assigning it avoids dumping the whole model repr as cell output.
model = train_model(model=model, criterion=criterion, optimizer=optimizer, scheduler=scheduler, num_epochs=EPOCHS)