In [1]:
from transformers import AutoImageProcessor, ViTModel
import torch
import os
from data import NIRCamDataset_ViT
from glob import glob
from torch.utils.data import DataLoader
from tqdm import tqdm

In [2]:
#https://huggingface.co/blog/fine-tune-vit

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

dataset_path = '/data/scratch/bariskurtkaya/dataset/NIRCAM/1386/injections_24102023/test'
dataset = glob(os.path.join(dataset_path, '*.npy'))

image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

In [4]:
nircam = NIRCamDataset_ViT(dataset, device)
nircam_loader = DataLoader(nircam, batch_size=256, shuffle=True)

In [5]:
class ViT(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.linear = torch.nn.Linear(768, 256)
        self.linear2 = torch.nn.Linear(256, 2)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        x = self.model(x).pooler_output
        x = self.linear(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x

In [6]:
model = ViT(model).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

# Do not train base model
for param in model.model.parameters():
    param.requires_grad = False

In [7]:
for epoch in range(10):
    model.train()
    for idx, (samples, labels) in tqdm(enumerate(nircam_loader)):
        inputs = image_processor(samples, return_tensors="pt").to(device)
        outputs = model(inputs.pixel_values)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if idx % 100 == 0:
            print(f"Epoch {epoch} Batch {idx} Loss {loss.item()}")

    lr_scheduler.step()
    model.eval()
    correct = 0
    total = 0
    for idx, (samples, labels) in enumerate(nircam_loader):
        inputs = image_processor(samples, return_tensors="pt").to(device)
        outputs = model(inputs.pixel_values)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    print(f"Epoch {epoch} Accuracy {correct/total}")

torch.save(model.state_dict(), 'model.pth')

0it [00:00, ?it/s]

It looks like you are trying to rescale already rescaled images. If the input images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again.
1it [00:05,  5.01s/it]

Epoch 0 Batch 0 Loss 0.6931940317153931


86it [03:10,  2.21s/it]


Epoch 0 Accuracy 0.5062732975725066


1it [00:02,  2.18s/it]

Epoch 1 Batch 0 Loss 0.6923569440841675


86it [03:06,  2.17s/it]


Epoch 1 Accuracy 0.4937267024274934


1it [00:02,  2.17s/it]

Epoch 2 Batch 0 Loss 0.6949406266212463


86it [03:06,  2.17s/it]


Epoch 2 Accuracy 0.5062732975725066


1it [00:02,  2.16s/it]

Epoch 3 Batch 0 Loss 0.6960892677307129


86it [03:04,  2.15s/it]


Epoch 3 Accuracy 0.5062732975725066


1it [00:02,  2.30s/it]

Epoch 4 Batch 0 Loss 0.6937497854232788


86it [03:06,  2.17s/it]


In [None]:
inputs = image_processor(sample[0], return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

last_hidden_states = outputs.last_hidden_state

It looks like you are trying to rescale already rescaled images. If the input images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again.


In [13]:
last_hidden_states.shape

torch.Size([4, 197, 768])