In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import glob
from PIL import Image
import torchvision
import random
import numpy

In [8]:
random.seed(0)
stroke = glob.glob("./input/stroke_data/*")
no_stroke = glob.glob("./input/noStroke_data/*")

print(stroke)

ratio = 0.2
stroke_val = random.sample(stroke, int(ratio*len(stroke)))
no_stroke_val = random.sample(no_stroke, int(ratio*len(no_stroke)))

stroke_train = list(set(stroke).difference(set(stroke_val)))
no_stroke_train = list(set(no_stroke).difference(set(no_stroke_val)))

stroke_train = [(i, 0) for i in stroke_train]
no_stroke_train = [(i, 1) for i in no_stroke_train]

stroke_val = [(i, 0) for i in stroke_val]
no_stroke_val = [(i, 1) for i in no_stroke_val]

['./input/stroke_data/aug_0_6814.jpg', './input/stroke_data/aug_0_3944.jpg', './input/stroke_data/aug_0_7263.jpg', './input/stroke_data/aug_0_3777.jpg', './input/stroke_data/aug_0_1606.jpg', './input/stroke_data/aug_0_8144.jpg', './input/stroke_data/aug_0_5448.jpg', './input/stroke_data/aug_0_3987.jpg', './input/stroke_data/aug_0_665.jpg', './input/stroke_data/aug_0_4742.jpg', './input/stroke_data/aug_0_6155.jpg', './input/stroke_data/aug_0_5879.jpg', './input/stroke_data/aug_0_7934.jpg', './input/stroke_data/aug_0_5110.jpg', './input/stroke_data/aug_0_8434.jpg', './input/stroke_data/aug_0_5676.jpg', './input/stroke_data/aug_0_2125.jpg', './input/stroke_data/aug_0_301.jpg', './input/stroke_data/aug_0_3549.jpg', './input/stroke_data/aug_0_9689.jpg', './input/stroke_data/aug_0_5925.jpg', './input/stroke_data/aug_0_8540.jpg', './input/stroke_data/aug_0_7673.jpg', './input/stroke_data/aug_0_9138.jpg', './input/stroke_data/aug_0_9662.jpg', './input/stroke_data/aug_0_2737.jpg', './input/stro

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class dataset(torch.utils.data.Dataset):
    
    def __init__(self, img_list, test=False):
        self.img_list = img_list
        self.test = test
        self.augmentations = None
        
    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        img_dir, label = self.img_list[idx]
        img = Image.open(img_dir)
        img = img.resize((256, 256), Image.LINEAR)
        if not self.test and self.augmentations:
            img = self.augmentations(img)
        img = torchvision.transforms.ToTensor()(img)
        return img, label
    
model = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.DEFAULT)
model.fc = nn.Sequential(
    nn.Linear(in_features=2048, out_features=1,  bias=True)
)

train_dataset = dataset(stroke_train + no_stroke_train, test=False)
val_dataset = dataset(stroke_val + no_stroke_val, test=True)

trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True, drop_last=True, num_workers=2, pin_memory=True)
valloader = torch.utils.data.DataLoader(val_dataset, batch_size=8, shuffle=False, drop_last=False, num_workers=2, pin_memory=True)

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
criterion = nn.BCEWithLogitsLoss()
model.to(device)

for epoch in range(20):
    for iteration, (imgs, lbls) in enumerate(trainloader):
        optimizer.zero_grad()
        model.train()
        imgs = imgs.to(device)
        lbls = lbls.to(device).float()
        preds = model(imgs).squeeze()
        loss = criterion(preds, lbls)
        loss.backward()
        optimizer.step()
        if iteration % 10 == 0:
            print(f'Epoch: {epoch} Iteration: {iteration} Loss: {loss}')
    model.eval()
    confusion_matrix = torch.zeros((2, 2))
    for imgs, lbls in valloader:
        imgs = imgs.to(device)
        lbls = lbls.to(device)
        
        with torch.no_grad():
            outputs = model(imgs).squeeze()
        outputs = F.sigmoid(outputs)
        preds = torch.where(outputs > 0.5, 1, 0)
        for t, p in zip(lbls.view(-1), preds.view(-1)):
            confusion_matrix[t.long(), p.long()] += 1
        per_class_acc = confusion_matrix.diag()/confusion_matrix.sum(1)
    print(f'Stroke accuracy: {per_class_acc[0].item()}, Non stroke accuracy: {per_class_acc[1].item()}')
    torch.save({"model_state_dict": model.state_dict(), 
                        "iteration": iteration,
                        },str(iteration) + ".pkl")

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /Users/audreylai/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100.0%


ValueError: num_samples should be a positive integer value, but got num_samples=0