In [None]:
!pip install torchmetrics

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchmetrics
  Downloading torchmetrics-0.11.0-py3-none-any.whl (512 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m512.4/512.4 KB[0m [31m21.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torchmetrics
Successfully installed torchmetrics-0.11.0


In [None]:
import torch
from torch import nn
#from torch.nn import functional as F
from torchvision import transforms
from torch.utils.data import DataLoader,Dataset
#from torch import optim
import os
import csv
#from PIL import Image
import warnings
warnings.simplefilter('ignore')
from torchvision.datasets import ImageFolder
from torchmetrics.functional.classification import multiclass_precision
from torchmetrics.functional.classification import multiclass_recall
from google.colab import drive

In [None]:
# Preprocessing shared by the train, validation and test sets:
# resize, center-crop to 256x256, convert to a float tensor, then normalize
# each channel with fixed per-channel statistics.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.4800, 0.4609, 0.4225),
            std=(0.2588, 0.2551, 0.2769),
        ),
    ]
)

In [None]:
# SUN397 images stored as class-per-folder trees on Google Drive.
# NOTE(review): these paths assume Drive is already mounted at /content/drive;
# `drive` is imported, but no visible cell calls `drive.mount`.
testset = ImageFolder(root="/content/drive/MyDrive/data_sun397_test", transform=transform)
trainset = ImageFolder(root="/content/drive/MyDrive/data_sun397", transform=transform)

# Split the held-out folder 60/40 into test and validation subsets. The split
# uses its own seeded generator so it is reproducible; the global
# torch.manual_seed is only called in a later cell.
test_size = int(len(testset) * 0.6)
valid_size = len(testset) - test_size
test_subset, valid_subset = torch.utils.data.random_split(
    testset,
    [test_size, valid_size],
    generator=torch.Generator().manual_seed(1234),
)

# Never request more worker processes than the machine has CPUs.
num_workers = min(32, os.cpu_count() or 1)

train = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=num_workers)
# Precision, recall and the summed loss do not depend on sample order, so the
# evaluation loaders are left unshuffled (deterministic iteration).
test = DataLoader(test_subset, batch_size=64, shuffle=False, num_workers=num_workers)
valid = DataLoader(valid_subset, batch_size=64, shuffle=False, num_workers=num_workers)

In [None]:
def batch_mean_and_sd(loader):
    """Compute the per-channel mean and (population) standard deviation of a dataset.

    Keeps running first and second raw moments, so the whole dataset never
    has to fit in memory:  std = sqrt(E[x^2] - E[x]^2).

    Args:
        loader: iterable yielding ``(images, labels)``. ``images`` is unpacked
            as ``(batch, channels, height, width)``; labels are ignored.

    Returns:
        ``(mean, std)``, each a 1-D tensor with one entry per channel.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    cnt = 0
    fst_moment = None
    snd_moment = None

    for images, _ in loader:
        b, c, h, w = images.shape
        nb_pixels = b * h * w
        if fst_moment is None:
            # Start from zeros, not torch.empty. On the first batch cnt == 0,
            # but 0 * (uninitialised NaN/inf) is still NaN/inf.
            fst_moment = torch.zeros(c, device=images.device)
            snd_moment = torch.zeros(c, device=images.device)
        sum_ = torch.sum(images, dim=[0, 2, 3])
        sum_of_square = torch.sum(images ** 2, dim=[0, 2, 3])
        # Running means, weighted by the number of pixels seen so far.
        fst_moment = (cnt * fst_moment + sum_) / (cnt + nb_pixels)
        snd_moment = (cnt * snd_moment + sum_of_square) / (cnt + nb_pixels)
        cnt += nb_pixels

    if fst_moment is None:
        raise ValueError("loader yielded no images")

    # E[x^2] - E[x]^2 can dip slightly below zero through floating-point
    # cancellation; clamp so sqrt never yields NaN.
    mean = fst_moment
    std = torch.sqrt(torch.clamp(snd_moment - fst_moment ** 2, min=0))
    return mean, std


In [None]:
# NOTE(review): `train` wraps ImageFolder(..., transform=transform), and that
# transform already ends in Normalize. These statistics therefore describe the
# normalized images (roughly 0 mean / 1 std if the constants fit the data),
# not the raw pixel values Normalize's constants should come from.
stats = batch_mean_and_sd(loader=train)
mean, std = stats
print("mean and std: \n", *stats)

In [None]:
import torchvision.models

# ResNet-18 architecture with randomly initialised weights (weights=None, so no
# pretrained checkpoint is downloaded). Its `fc` head is replaced further down.
model = torchvision.models.resnet18(weights=None)

class Classifier(nn.Module):
  """Linear classification head that replaces ResNet-18's ``fc`` layer.

  Returns raw logits. The notebook trains with ``nn.CrossEntropyLoss``, which
  applies log-softmax internally. A Softmax here as well would double-normalize
  the outputs and flatten the gradients. Argmax-based metrics are unaffected,
  because softmax is monotonic.

  Args:
    in_features: size of the incoming feature vector (512 for ResNet-18).
    num_classes: number of output classes (397 for SUN397).
  """

  def __init__(self, in_features=512, num_classes=397):
    super().__init__()
    # Keep the attribute name `linear` so existing state dicts
    # (keys `fc.linear.weight` / `fc.linear.bias`) still load.
    self.linear = nn.Linear(in_features, num_classes)

  def forward(self, x):
    # x: (batch, in_features) -> (batch, num_classes) unnormalized logits.
    return self.linear(x)

# Swap ResNet-18's final fully connected layer for the SUN397 classification
# head, then show the resulting module tree.
model.fc = Classifier()
print(repr(model))

In [None]:
# Resume from a previously saved state dict for this ResNet-18 + Classifier model.
# map_location='cpu' lets a checkpoint saved from a GPU run load on any kernel;
# the model is moved to `device` in a later cell.
# NOTE(review): a fresh "Run All" needs this file to exist at /content.
model.load_state_dict(torch.load('/content/resnet_sun397 tot.txt', map_location='cpu'))

In [None]:
# Seed PyTorch's RNGs for the training run. The trailing ';' suppresses the
# returned Generator's repr in the notebook output.
# NOTE(review): this cell runs after the dataset split and the model
# construction, so those earlier random steps are not covered by this seed.
torch.manual_seed(1234);

<torch._C.Generator at 0x7f87b6e35c70>

In [None]:
# Use the GPU when one is present. Otherwise fall back to the CPU rather than
# failing at `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# CrossEntropyLoss takes raw logits and integer class labels.
loss_fn = nn.CrossEntropyLoss()
# AdamW with initial learning rate 1e-4. The rate is lowered later by the
# ReduceLROnPlateau scheduler, which tracks the validation loss.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)


In [None]:
def get_metrics(model, loader, num_classes=397, device=None):
    """Compute multiclass precision and recall of `model` over every batch in `loader`.

    Each batch is reduced to predicted class indices (argmax over the model
    output), so only two 1-D index tensors are kept rather than an
    (N, num_classes) score matrix. torchmetrics would argmax float predictions
    itself, so the result matches passing the scores directly. Targets are kept
    as integer tensors, which is the dtype torchmetrics expects.

    The caller is responsible for putting `model` in eval mode.

    Args:
        model: classifier returning (batch, num_classes) scores.
        loader: iterable of ``(images, labels)`` with integer class labels.
        num_classes: number of classes (397 for SUN397).
        device: device to run inference on; defaults to the device of the
            model's parameters.

    Returns:
        ``(precision, recall)`` as computed by torchmetrics' functional
        ``multiclass_precision`` / ``multiclass_recall`` (default averaging).

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    if device is None:
        device = next(model.parameters()).device

    all_labels = []
    all_preds = []
    with torch.no_grad():
        for img, label in loader:
            logits = model(img.to(device))
            all_preds.append(logits.argmax(dim=1))
            all_labels.append(label.to(device))

    if not all_labels:
        raise ValueError("loader yielded no batches")

    labels = torch.cat(all_labels)
    pred_labels = torch.cat(all_preds)
    return (
        multiclass_precision(pred_labels, labels, num_classes),
        multiclass_recall(pred_labels, labels, num_classes),
    )


In [None]:
# Lower the learning rate (by ReduceLROnPlateau's default factor) once the
# value passed to scheduler.step() has not decreased for 6 consecutive steps.
# The training loop steps it once per epoch with the validation loss.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=6)

In [None]:
def _mean_validation_loss(model, loader, loss_fn, device):
    """Return the average per-sample loss of `model` over `loader`, without autograd."""
    total_loss, n_samples = 0.0, 0
    with torch.no_grad():  # evaluation only: no graph is needed for backprop
        for img, label in loader:
            img, label = img.to(device), label.to(device)
            # loss_fn uses the default reduction='mean' (batch average), so
            # scale back up to a per-batch sum before accumulating.
            total_loss += loss_fn(model(img), label).item() * img.size(0)
            n_samples += img.size(0)
    return total_loss / max(n_samples, 1)


# Train for a fixed number of epochs. After each epoch:
#   - compute validation precision/recall and the mean validation loss;
#   - checkpoint whenever validation precision improves;
#   - let ReduceLROnPlateau react to the validation loss.
num_epochs = 60
best_epoch, best_prec = 0, 0  # best_epoch is 0-based; checkpoint names use epoch + 1
valid_loss = 0
loss_history = []     # training loss, sampled every 50th batch
prec_history = []     # validation precision, one entry per epoch
recall_history = []   # validation recall, one entry per epoch
for epoch in range(num_epochs):
    model.train()
    for batch_num, (img, label) in enumerate(train):
        img, label = img.to(device), label.to(device)
        optimizer.zero_grad()
        logits = model(img)

        loss = loss_fn(logits, label)

        if (batch_num + 1) % 50 == 0:
            print('{}th batch of the {}th epoch, loss {}'.format(batch_num + 1, epoch + 1,
                                                                 loss.item()))
            loss_history.append(loss.item())

        loss.backward()
        optimizer.step()

    model.eval()
    prec, rec = get_metrics(model, valid)

    prec_history.append(prec.item())
    recall_history.append(rec.item())

    # Mean loss per validation sample. Dividing the per-sample sum by the number
    # of batches would overstate it by roughly the batch size.
    valid_loss = _mean_validation_loss(model, valid, loss_fn, device)

    if prec > best_prec:
        print('Precision: {}'.format(prec))
        print('Recall: {}'.format(rec))

        best_epoch = epoch
        best_prec = prec.item()
        torch.save(model.state_dict(), 'resnet_sun397 ep ' + str(epoch + 1) + '.txt')

    print(f'Epoch {epoch}\tValidation Loss:{valid_loss}')
    print('LR: ', optimizer.param_groups[0]['lr'])
    scheduler.step(valid_loss)

print('best_prec:{},best_epoch:{}'.format(best_prec, best_epoch))

In [None]:
# Final (precision, recall) of the trained model on the held-out test split.
print(get_metrics(model=model, loader=test))

(tensor(0.3542, device='cuda:0'), tensor(0.3532, device='cuda:0'))


In [None]:
# Persist the final-epoch weights. This is not necessarily the best-precision
# epoch; the training loop saves those checkpoints separately.
torch.save(obj=model.state_dict(), f='resnet_sun397.txt')