In [None]:
import torch
import timm

# Display the installed PyTorch version as this cell's output so results can be tied to a library version.
torch.__version__

In [None]:
import os
import sys
import openpyxl
import openpyxl.workbook.workbook as WB
import matplotlib.pyplot as plt

In [None]:
from torch.utils.data import Dataset, Subset, DataLoader
from torchvision import datasets
import torchvision.transforms as transforms
import torchvision
import cv2

class CellDataset(Dataset):
    """Paired-image dataset whose 'O'/'X' labels are read from an Excel sheet.

    Sample ``idx`` is the image pair ``{2*idx+1}.jpg`` (start) and
    ``{2*idx+2}.jpg`` (end) under ``root``. Its two labels come from columns I
    and J of the workbook's active sheet. Sample 0 is row ``start_row`` and each
    following sample advances ``row_step`` rows, stopping before ``end_row``.

    Args:
        xlsx_path: Path of the workbook holding the labels.
        root: Directory containing the numbered ``.jpg`` images.
        transform: Optional callable applied to each image as returned by
            ``cv2.imread`` (a BGR array).
        num_cls: Target encoding.
            - 3 gives {XX: 0, OX: 1, XO: 1, OO: 2}.
            - Any other value gives {XX: 0, OX: 1, XO: 2, OO: 3}.
        start_row: First worksheet row holding labels (default 5).
        end_row: Exclusive upper bound on worksheet rows (default 605).
        row_step: Row stride between consecutive samples (default 2).

    ``__getitem__`` returns ``(start_image, end_image, target)``.
    """

    # (label I, label J) -> class index when the two "mixed" outcomes share a class.
    _TARGETS_3CLS = {("X", "X"): 0, ("O", "X"): 1, ("X", "O"): 1, ("O", "O"): 2}
    # (label I, label J) -> class index when the two "mixed" outcomes are distinct.
    _TARGETS_4CLS = {("X", "X"): 0, ("O", "X"): 1, ("X", "O"): 2, ("O", "O"): 3}

    def __init__(self, xlsx_path="dataset.xlsx", root="data", transform=None, num_cls=3,
                 start_row=5, end_row=605, row_step=2):
        self.root = root
        self.transform = transform
        self.num_cls = num_cls
        xl = openpyxl.load_workbook(xlsx_path, data_only=True)
        try:
            sht = xl.active
            # self.ox[idx] == [value in column I, value in column J] for that sample's row.
            self.ox = [[sht["i" + str(row)].value, sht["j" + str(row)].value]
                       for row in range(start_row, end_row, row_step)]
        finally:
            xl.close()

    def _load_image(self, image_number):
        """Read ``{image_number}.jpg`` from ``root`` and apply ``transform`` if set.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        path = os.path.join(self.root, f"{image_number}.jpg")
        image = cv2.imread(path)
        # cv2.imread signals failure by returning None rather than raising.
        if image is None:
            raise FileNotFoundError(f"could not read image: {path}")
        # NOTE(review): OpenCV decodes to BGR, while ImageNet-pretrained backbones
        # are usually trained on RGB. Left unchanged to stay compatible with
        # existing checkpoints; consider cv2.cvtColor(image, cv2.COLOR_BGR2RGB).
        if self.transform:
            image = self.transform(image)
        return image

    def _encode_target(self, labels):
        """Map a ``[label_I, label_J]`` pair of 'O'/'X' marks to a class index.

        String marks are whitespace-stripped and upper-cased before lookup.

        Raises:
            ValueError: if the pair is not one of XX, OX, XO, OO.
        """
        table = self._TARGETS_3CLS if self.num_cls == 3 else self._TARGETS_4CLS
        key = tuple(v.strip().upper() if isinstance(v, str) else v for v in labels)
        try:
            return table[key]
        except KeyError:
            raise ValueError(f"unexpected label pair {labels!r}; expected 'O'/'X' marks") from None

    def __getitem__(self, idx):
        # Decode the label first so malformed rows fail before any image I/O.
        target = self._encode_target(self.ox[idx])
        start_image = self._load_image(idx * 2 + 1)
        end_image = self._load_image(idx * 2 + 2)
        return start_image, end_image, target

    def __len__(self):
        return len(self.ox)
            

In [None]:
batch_size = 64
num_epoch = 300
lr = 0.005
model_name = "mobilenetv3_large_100"
num_cls = 3            # 3 or 4 classes; must match CellDataset's target encoding
xlsx_path = 'dataset.xlsx'
root_path = 'data'
device = 'cuda:1'      # None/'' auto-selects; CUDA requests fall back to CPU when CUDA is unavailable
eval_every = 10        # evaluate on the held-out split (and maybe checkpoint) every N epochs
seed = 0               # seeds torch RNG (shuffling, head initialisation)
save_dir = 'models'
###################################

torch.manual_seed(seed)

# define transform
transform = torchvision.transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((360, 480)),
])

# load dataset
dataset = CellDataset(xlsx_path, root_path, transform, num_cls)

# separate dataset into train / test: first 80% of samples train, the rest test (order-preserving)
s = int(len(dataset) * 0.8)
indices = list(range(len(dataset)))
train_data, test_data = Subset(dataset, indices[:s]), Subset(dataset, indices[s:])
data_loader_train = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)
data_loader_test = DataLoader(test_data, batch_size=1, shuffle=False, num_workers=4)

# load model; let timm build the classification head so its input width matches
# the backbone and its output width matches num_cls
model = timm.create_model(model_name, pretrained=True, num_classes=num_cls)

# define loss function, optimizer, scheduler
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)  # weight_decay=0.0005
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)

# define device
if not device or (str(device).startswith('cuda') and not torch.cuda.is_available()):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

model.to(device)
criterion.to(device)

os.makedirs(save_dir, exist_ok=True)
best_path = os.path.join(save_dir, f"cell_{model_name}_{num_cls}cls_best.pth")

# train
min_false = float('inf')
for epoch in range(num_epoch):
    model.train()
    total_loss = 0.0
    for start_image, end_image, target in data_loader_train:  # images are NCHW
        # Put the start and end frames side by side (along width) as one input image.
        image = torch.cat([start_image, end_image], dim=3).to(device)
        target = target.to(device)  # already a tensor after DataLoader collation

        optimizer.zero_grad()
        output = model(image)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # .item() detaches, so each batch's autograd graph is freed. Weighting by
        # the real batch size keeps the epoch mean correct for a short last batch.
        total_loss += loss.item() * target.size(0)
    total_loss /= max(len(train_data), 1)

    print(f"epoch:{epoch} loss: {total_loss}")

    # test & save
    if epoch % eval_every == 0:
        model.eval()
        true = 0
        false = 0
        with torch.no_grad():
            for start_image, end_image, target in data_loader_test:
                image = torch.cat([start_image, end_image], dim=3).to(device)
                pred = model(image).argmax(dim=1).cpu()
                n_correct = (pred == target).sum().item()
                true += n_correct
                false += target.size(0) - n_correct

        print(f"true:{true} false:{false}")
        # '<=' keeps the most recent checkpoint among equally good evaluations.
        if false <= min_false:
            min_false = false
            torch.save(model.state_dict(), best_path)
    scheduler.step()