In [1]:
import torch
import numpy as np
import torchvision
import pandas as pd
import bisect
import os
import glob
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import torchvision.models as models
import torch.nn as nn
from tqdm import tqdm
import time
from heapq import heappop, heappush, heapify
from collections import defaultdict


import matplotlib.pyplot as plt

In [2]:
from src.utils.network import SupConResNet_v
from src.utils.loss import SupConLoss
from src.utils.stem_dataset import STEMDataset, ToTensor, collate_fn, RandomCrop

## Parameters you may need to change

In [3]:
# Root directory of the atomagined dataset. Point this at your local copy by
# setting the ATOMAGINED_DIR environment variable or by editing the default
# below; csv_file and root_dir are both derived from it.
data_root = os.environ.get("ATOMAGINED_DIR", "/home/weixin/Documents/data/MaterialEyes/atomagined")
csv_file = os.path.join(data_root, "key.csv") # path to the annotation file
# The trailing "" keeps the trailing path separator that the original root_dir had.
root_dir = os.path.join(data_root, "general", "png", "") # dir of the STEM images
batch_size = 20 # parameter for model training, depends on the GPU MEMORY
image_size = 200 # parameter for input enhancement for training (crop_size)
# gpu id selection (in case of more than one gpu); falls back to the CPU when CUDA is unavailable
device = "cuda:0" if torch.cuda.is_available() else "cpu"
max_epoch = 1000 # the maximum epochs for model training

In [4]:
# Build the dataset and its loader, reporting how long construction takes.
start = time.time()

# Per-sample pipeline: RandomCrop(image_size) followed by ToTensor().
stem_transform = transforms.Compose([RandomCrop(image_size), ToTensor()])
dataset = STEMDataset(csv_file, root_dir, transform=stem_transform)

# Shuffled batches; drop_last=True discards a final incomplete batch.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    drop_last=True,
)

elapsed = time.time() - start
print("loading dataset. Time elapsed %.2f" % elapsed)

# Cell output: (number of samples, number of batches per epoch)
(len(dataset), len(dataloader))

loading dataset. Time elapsed 131.17


(5998, 299)

In [5]:

"""
Network parameters:
name: the ResNet variant used as the backbone [resnet18, resnet34, resnet50, resnet101]
head: the projection head used for contrastive learning [mlp, linear, none]
cls: the classification task [icsd, symtable]
feat_dim: the dimensionality of the metric space in which the contrastive loss
is computed; it only takes effect if the head is not none
"""

# Build the network: resnet18 backbone, mlp projection head, icsd classification.
net = SupConResNet_v(
    name='resnet18',
    head="mlp",
    cls="icsd",
    feat_dim=128,
)

# Optional: initialize from a pretrained model by uncommenting the line below.
#net.load_state_dict(torch.load("resnet18_cls_none_sim_none_epoch181.pt", map_location="cpu"), strict=False)

# Train

In [6]:
# Training setup: move the model to the target device, then create the
# optimizer and the two loss terms that are summed for every batch.
net.to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3) # optimizer, [Adam, SGD]
criterion_cls = torch.nn.CrossEntropyLoss() # cross entropy loss for classification task
criterion_sim = SupConLoss(temperature=0.1) # loss for contrastive learning
net.train()

for epoch in range(max_epoch):
    with tqdm(dataloader, unit="batch") as tepoch:
        for data in tepoch:
            # data["imgs"]: batch of image data
            # data["labels"]: [ref_ids, icsd_ids, symtable_ids]
            optimizer.zero_grad()
            imgs = data["imgs"].to(device)
            # Index 1 selects the icsd_ids entry; cast to int64 for CrossEntropyLoss.
            labels = data["labels"][1].to(torch.int64).to(device)
            pred, cls_pred = net(imgs)
            # Split the projections into two equal halves along dim 0 and pair
            # them up as (N, 2, ...). Halving by the tensor's own size instead of
            # the global batch_size keeps this correct when batch_size is edited
            # after the DataLoader was built, or when a batch is smaller.
            # Presumably the halves are two views of the same samples - TODO confirm
            # against collate_fn.
            f1, f2 = torch.chunk(pred, 2, dim=0)
            pred = torch.stack([f1, f2], dim=1)
            cls_loss = criterion_cls(cls_pred, labels)
            # NOTE(review): no labels are passed to the contrastive loss here -
            # confirm that this is the intended mode of SupConLoss.
            sim_loss = criterion_sim(pred)
            loss = cls_loss + sim_loss
            loss.backward()
            optimizer.step()
            tepoch.set_postfix(loss=float(loss), cls_loss=float(cls_loss), sim_loss=float(sim_loss))
    # save model for every epoch
    torch.save(net.state_dict(), "./resnet18_epoch%.3d.pt"%(epoch))
    

  5%|▌         | 15/299 [00:21<06:52,  1.45s/batch, cls_loss=9.06, loss=14.2, sim_loss=5.16]


KeyboardInterrupt: 