In [1]:
import sys
sys.path.append('..')
import torch
import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
import torchvision.transforms as tf

from src.models.swin.model import Swin
from src.models.train import Trainer
from src.features.segmentation.dataset import SegmentationDataset

from transformers import AutoImageProcessor, Swinv2Model

In [2]:
image_processor = AutoImageProcessor.from_pretrained("microsoft/swinv2-tiny-patch4-window16-256")
image_processor.do_resize = False
image_processor.do_rescale = False
model = Swinv2Model.from_pretrained("microsoft/swinv2-tiny-patch4-window16-256")



In [3]:
swin = Swin(
    net=model,
    image_processor=image_processor,
    device="mps"
)

In [4]:
device = torch.device("mps")
swin = swin.to(device)

In [5]:
for name, param in swin.named_parameters():
    if param.device != "mps:0":
        print(param.device)
        break

mps:0


In [6]:
transform = tf.Resize((256, 256))

In [7]:
train_dataset = SegmentationDataset(
    images_dir="../data/split/train/train_images/",
    masks_dir="../data/split/train/train_masks/",
    image_transform=transform,
    mask_transform=transform
)

val_dataset = SegmentationDataset(
    images_dir="../data/split/valid/valid_images/",
    masks_dir="../data/split/valid/valid_masks/",
    image_transform=transform,
    mask_transform=transform
)


train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=2, shuffle=False)


In [10]:
for image, target in train_dataloader:
    image = image.to(device)
    target = target.to(device)
    out = swin.val_on_batch(image, target)
    print(out[1])
    break

tensor([[[[0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.5855, 0.5000],
          ...,
          [0.5000, 0.5000, 0.5000,  ..., 0.6402, 0.8030, 0.5000],
          [0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000,  ..., 0.5509, 0.6979, 0.5000]]],


        [[[0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.5363, 0.5000],
          [0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.6248, 0.5000],
          ...,
          [0.5000, 0.5083, 0.5299,  ..., 0.6543, 0.8282, 0.5000],
          [0.5000, 0.5000, 0.5000,  ..., 0.5000, 0.5000, 0.5000],
          [0.5000, 0.5000, 0.5000,  ..., 0.5220, 0.6688, 0.5000]]]],
       device='mps:0', grad_fn=<SigmoidBackward0>)
