In [1]:
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torch.utils.data import DataLoader
import os
import torch 

#data_dir = "/home/ubuntu/MySAM/sam_data/sa_000020/image_dir"
data_dir = "/home/ubuntu/MySAM/sam_data/mini_image_dir/image_dir" #fix me
image_path_iter = DataLoader(os.listdir(data_dir), batch_size=2, shuffle=True)

import cv2
def load_image(image_path, root_dir):
    #image_path = "/home/ubuntu/sam_data/sa_000020/mini_image_dir/train/sam:1.jpg"
    image = cv2.imread(''.join([root_dir, '/', image_path]))
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    return torch.from_numpy(image).permute(2, 0, 1)

def load_embedding(image_path, embedding_dir):
    return torch.load(''.join([embedding_dir, '/', image_path.split('.')[0], '.pth'])) 

paths = data_dir.split('/')
paths[-1] = 'embedding_dir'
embedding_dir = "/".join(paths)
if not os.path.exists(embedding_dir):
    raise Exception("Directory not exist!")
    

In [2]:
from student_modeling import TinyViT
student_image_encoder=TinyViT(img_size=1024, in_chans=3, num_classes=1000,
                embed_dims=[64, 128, 160, 320],
                depths=[2, 2, 6, 2],
                num_heads=[2, 4, 5, 10],
                window_sizes=[7, 7, 14, 7],
                mlp_ratio=4.,
                drop_rate=0.,
                drop_path_rate=0.0,
                use_checkpoint=False,
                mbconv_expand_ratio=4.0,
                local_conv_size=3,
                layer_lr_decay=0.8
            )

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
_ = student_image_encoder.to(device)

In [3]:
from torch.nn import functional as F
pixel_mean = torch.tensor([[[123.6750]], [[116.2800]], [[103.5300]]]).to(device)
pixel_std = torch.tensor([[[58.3950]], [[57.1200]], [[57.3750]]]).to(device)

def preprocess(x, pixel_mean, pixel_std):
    x = (x - pixel_mean) / pixel_std
    # Pad
    h, w = x.shape[-2:]
    padh = 1024 - h
    padw = 1024 - w
    x = F.pad(x, (0, padw, 0, padh))
    return x

In [4]:
num_epochs = 1
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(student_image_encoder.parameters(), lr=0.001)
for epoch in range(num_epochs):
    student_image_encoder.train()
    train_loss = 0.0
    for batch_path in image_path_iter:
        input_images = torch.stack([preprocess(load_image(path, data_dir).to(device), pixel_mean, pixel_std) for path in batch_path])
        embedding = student_image_encoder(input_images)
        target_embedding = torch.stack([load_embedding(path, embedding_dir).to(device) for path in batch_path])
        loss = criterion(embedding, target_embedding)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * input_images.shape[0]
    print('epoch %d, loss %.4f' % (epoch + 1, train_loss / len(image_path_iter.dataset)))

epoch 1, loss 1.0094
