In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
import torchvision
from torch.utils.data import Dataset, DataLoader, random_split, SubsetRandomSampler, WeightedRandomSampler
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cpu


In [2]:
transform = transforms.Compose(
    [transforms.Grayscale(num_output_channels=3),transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.1, hue=0.1),transforms.RandomRotation(90, resample=False, expand=False, center=None, fill=None),transforms.Resize((224,224)),
     transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
     ])
trf=transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor()
])
trainset = ImageFolder('train_3',transform=transform)
testset = ImageFolder('test_3',transform=trf)
c=trainset.class_to_idx
idx2class = {v: k for k, v in c.items()}
idx2class

{0: 'Acrobeloides',
 1: 'Aphelenchoides',
 2: 'Aporcelaimus',
 3: 'Axonchium',
 4: 'Discolimus',
 5: 'Ditylenchus',
 6: 'Eudorylaimus',
 7: 'Helicotylenchus',
 8: 'Mesodorylaimus',
 9: 'Miconchus',
 10: 'Mylonchulus',
 11: 'Panagrolaimus',
 12: 'Pratylenchus',
 13: 'Pristionchus',
 14: 'Rhbiditis',
 15: 'Xenocriconema'}

In [3]:
print(len(trainset))
print(len(testset))

4217
505


In [4]:
def get_class_distribution(dataset_obj):
    count_dict = {k:0 for k,v in dataset_obj.class_to_idx.items()}
    
    for element in dataset_obj:
        y_lbl = element[1]
        y_lbl = idx2class[y_lbl]
        count_dict[y_lbl] += 1
            
    return count_dict
print("Distribution of classes: \n", get_class_distribution(trainset))

Distribution of classes: 
 {'Acrobeloides': 297, 'Aphelenchoides': 278, 'Aporcelaimus': 309, 'Axonchium': 272, 'Discolimus': 285, 'Ditylenchus': 262, 'Eudorylaimus': 276, 'Helicotylenchus': 248, 'Mesodorylaimus': 289, 'Miconchus': 280, 'Mylonchulus': 220, 'Panagrolaimus': 261, 'Pratylenchus': 251, 'Pristionchus': 193, 'Rhbiditis': 264, 'Xenocriconema': 232}


In [5]:
target_list = torch.tensor(trainset.targets)
target_list = target_list[torch.randperm(len(target_list))]

class_count = [i for i in get_class_distribution(trainset).values()]
class_weights = 1./torch.tensor(class_count, dtype=torch.float) 
class_weights

tensor([0.0034, 0.0036, 0.0032, 0.0037, 0.0035, 0.0038, 0.0036, 0.0040, 0.0035,
        0.0036, 0.0045, 0.0038, 0.0040, 0.0052, 0.0038, 0.0043])

In [6]:
class_weights_all = class_weights[target_list]
class_weights_all

tensor([0.0034, 0.0043, 0.0032,  ..., 0.0036, 0.0045, 0.0038])

In [7]:
len(class_weights_all)

4217

In [8]:
weighted_sampler = WeightedRandomSampler(
    weights=class_weights_all,
    num_samples=len(class_weights_all),
    replacement=True
)

In [9]:
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32,
                                          shuffle=False, sampler=weighted_sampler)

testloader = torch.utils.data.DataLoader(testset, batch_size=32,
                                         shuffle=True, num_workers=0)

classes = ('Helicotylenchus','Xenocriconema','Mylonchulus','Ditylenchus','Panagrolaimus','Rhbiditis','Pratylenchus','Acrobeloides','Pristionchus','Aphelenchoides','Axonchium','Aporcelaimus','Discolimus','Eudorylaimus','Mesodorylaimus','Miconchus')

In [12]:
from torchvision import models
model=models.resnet18(pretrained=True)
for param in list(model.parameters())[:-1]:
    param.requires_grad=False
model=model.to(device)
print(model)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [13]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9,weight_decay=0)

In [None]:
train_loss=[]

train_acc=[]

test_loss=[]

test_acc=[]

num_epoch=3

for epoch in range(num_epoch):  # loop over the dataset multiple times
    
    running_train_loss = 0.0
    running_test_loss = 0.0
    running_acc=0.0
    running_test_acc=0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # print statistics
        running_train_loss += loss.item()
        _,predicted = torch.max(outputs,1)
        num_correct=(predicted==labels).sum()
        running_acc +=num_correct.item()/len(trainset)
        if i % 20 == 19:    # print every 20 mini-batches
            print('[%d, %5d] train_loss: %.3f' %
                  (epoch + 1, i + 1, running_train_loss / 20))
            print('[%d, %5d] running_acc: %.3f' %
                  (epoch + 1, i + 1, running_acc))
            train_loss.append(running_train_loss/20)
            train_acc.append(running_acc)
            
    for i,data in enumerate(testloader, 0):
        inputs, labels = data
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        running_test_loss += loss.item()
        _,predicted = torch.max(outputs,1)
        num_correct=(predicted==labels).sum()
        running_test_acc +=num_correct.item()/len(testset)
        if i % 5 == 4:    # print every 5 mini-batches
            print('[%d, %5d] test_loss: %.3f' %
                  (epoch + 1, i + 1, running_test_loss / 5))
            print('[%d, %5d] test_acc: %.3f' %
                  (epoch + 1, i + 1, running_test_acc))
            test_loss.append(running_test_loss/5)
            test_acc.append(running_test_acc)

print('Finished Training')

correct = 0
total = 0
with torch.no_grad():
    for data in trainloader:
        images, labels = data
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the traning images: %d %%' % (
    100 * correct / total))

In [None]:
PATH = 'model_was.pt'
torch.save(model.state_dict(), PATH)

In [None]:
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the test images: %d %%' % (
    100 * correct / total))

In [None]:
train_loss.append(running_train_loss)
len(train_loss)
print(train_loss)

In [None]:
train_acc.append(running_acc)
print(train_acc)

In [101]:
class_correct = list(0. for i in range(16))
class_total = list(0. for i in range(16))
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        c = (predicted == labels).squeeze()
        for i in range(9):
            label = labels[i]
            class_correct[label] += c[i].item()
            class_total[label] += 1


for i in range(16):
    print('Accuracy of %5s : %2d %%' % (
        classes[i], 100 * class_correct[i] / class_total[i]))

Accuracy of Helicotylenchus :  0 %
Accuracy of Xenocriconema : 100 %
Accuracy of Mylonchulus :  0 %
