In [1]:
import torch
import numpy as np
import os
import resnet
import pytorch_lightning as pl
import tqdm
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import copy
from torch.optim.lr_scheduler import ReduceLROnPlateau
# device = torch.device('cpu')
print(device)

cuda


In [2]:
INIT_LR = 1e-4
BATCH_SIZE = 15 #100
EPOCHS = 10 #10
# define the train and val splits
TRAIN_SPLIT = 0.75
VAL_SPLIT = 1 - TRAIN_SPLIT

In [3]:
model_out_name = "augrottrans"
classes = ['no','sphere','vort']

In [4]:
class LensTrainDataset(torch.utils.data.Dataset):
    def __init__(self, img_dir, sub_class_dir, u_length, classes):
        self.img_dir = img_dir
        self.classes = classes
        self.sub_class_dir = sub_class_dir
        self.img_labels = list(range(len(self.classes)))
        self.u_length = u_length
        self.sub_classes = {'rotation':7,'translation':4}
        # self.sub_classes = {'translation':4}
        self.n_subclasses = 0
        self.sub_class_identifier = []
        counter = 1
        for label in self.sub_classes:
            self.n_subclasses += self.sub_classes[label]
            self.sub_class_identifier.append(counter+self.sub_classes[label])
            counter += self.sub_classes[label]

    def __len__(self):
        additional_length = 0
        for label in self.sub_classes.keys():
            additional_length += self.sub_classes[label]*len(self.img_labels)*self.u_length
        return len(self.img_labels)*self.u_length+additional_length

    def __getitem__(self, idx):
        item_sub_class = idx//(self.u_length*len(self.classes))
        
        for i in range(len(self.sub_class_identifier)):
            if item_sub_class < self.sub_class_identifier[i]:
                sub_class_directory = list(self.sub_classes.keys())[i]
                break
        if item_sub_class == 0: sub_class_directory = 'None'
        item_class = (idx - item_sub_class*len(self.classes)*self.u_length)//self.u_length
        if item_sub_class == 0:
            img_path = os.path.join(self.img_dir, "%s/%s.npy"%(self.classes[item_class],(idx%self.u_length)+1))
            image = torch.Tensor(np.load(img_path))
        else:
            img_path = os.path.join(self.sub_class_dir, "%s/%s/%s_%d.npy"%(self.classes[item_class],sub_class_directory,(idx%self.u_length)+1,item_sub_class))
            img_path = os.path.join(self.img_dir, "%s/%s.npy"%(self.classes[item_class],(idx%self.u_length)+1))
            try:
                image = torch.Tensor(np.array(np.load(img_path)))
            except EOFError:
                print(img_path)
                image = torch.Tensor(np.array([np.load(img_path)]))

        label = self.img_labels[item_class]
        return image, label
    
class LensTrainOriginalDataset(torch.utils.data.Dataset):
    def __init__(self, img_dir, sub_class_dir, u_length, classes):
        self.img_dir = img_dir
        self.classes = classes
        # self.sub_class_dir = sub_class_dir
        self.img_labels = list(range(len(self.classes)))
        self.u_length = u_length
        # self.sub_classes = {'rotation':7}

    def __len__(self):
        # additional_length = 0
        # for label in self.sub_classes.keys():
        #     additional_length += self.sub_classes[label]*len(self.img_labels)*self.u_length
        return len(self.img_labels)*self.u_length

    def __getitem__(self, idx):
        item_sub_class = idx//(self.u_length*len(self.classes))
        item_class = (idx - item_sub_class*len(self.classes)*self.u_length)//self.u_length
        # if item_sub_class == 0:
        img_path = os.path.join(self.img_dir, "%s/%s.npy"%(self.classes[item_class],(idx%self.u_length)+1))
        image = torch.Tensor(np.load(img_path))
        # else:
        #     img_path = os.path.join(self.sub_class_dir, "%s/%s/%s_%d.npy"%(self.classes[item_class],'rotation',(idx%self.u_length)+1,item_sub_class))
        #     img_path = os.path.join(self.img_dir, "%s/%s.npy"%(self.classes[item_class],(idx%self.u_length)+1))
        #     try:
        #         image = torch.Tensor(np.array(np.load(img_path)))
        #     except EOFError:
        #         print(img_path)
        #         image = torch.Tensor(np.array([np.load(img_path)]))

        label = self.img_labels[item_class]
        return image, label
    
class LensTestDataset(torch.utils.data.Dataset):
    def __init__(self, img_dir, u_length, classes):
        self.img_dir = img_dir
        self.classes = classes
        self.img_labels = list(range(len(self.classes)))
        self.u_length = u_length

    def __len__(self):
        return len(self.img_labels)*self.u_length

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, "%s/%s.npy"%(self.classes[int(idx/self.u_length)],idx+1-int(idx/self.u_length)*self.u_length))
        image = torch.Tensor(np.load(img_path))
        label = self.img_labels[int(idx/self.u_length)]
        # if self.transform:
        #     image = self.transform(image)
        # if self.target_transform:
        #     label = self.target_transform(label)
        return image, label

In [5]:
class ResNetDeepLense(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = resnet.ResNet50(img_channels=1,num_classes=3)
        self.loss = torch.nn.CrossEntropyLoss()
    def forward(self, x):
        return self.model(x)
    
    # @auto_move_data
    def training_step(self, batch, batch_no):
        x, y = batch
        logits = self(x)
        loss = self.loss(logits, y)
        return loss
    
    def validation_step(self, batch, batch_no):
        x, y = batch
        logits = self(x)
        loss = self.loss(logits, y)
        return loss
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=INIT_LR)

In [6]:
train_dataset = LensTrainDataset('../dataset/train/', '/media/anirudh/Extreme SSD/DeepLense/augmented_data', 10000, classes)
test_dataset = LensTestDataset('../dataset/val/', 2500, classes)
labels_length = len(train_dataset.img_labels)

numTrainSamples = int(len(train_dataset) * TRAIN_SPLIT)
numValSamples = int(len(train_dataset) * VAL_SPLIT)
print("[INFO] %d samples to train, and %d to validate over %d epochs"%(numTrainSamples,numValSamples,EPOCHS))
(train_dataset, val_dataset) = torch.utils.data.random_split(train_dataset,
	[numTrainSamples, numValSamples])

train_dataloader = torch.utils.data.DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=15)
val_dataloader = torch.utils.data.DataLoader(val_dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=15)
test_dataloader = torch.utils.data.DataLoader(test_dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=15)

trainSteps = len(train_dataloader.dataset) // BATCH_SIZE
valSteps = len(val_dataloader.dataset) // BATCH_SIZE

print("[INFO] initializing the ResNet model...")
model = ResNetDeepLense().to(device)
# model = lenet.LeNet(1,3).to(device)

batches_per_epoch = int(len(train_dataset)/(EPOCHS*BATCH_SIZE))
loss_fn = torch.nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(),lr=INIT_LR)
best_acc = -np.inf
best_weights = None

scheduler = ReduceLROnPlateau(opt,'min',patience=2,factor=0.1,verbose=True)
history = {'val_loss':[],'val_acc':[]}
for epoch in range(EPOCHS):
    epoch_loss = []
    epoch_acc = []
    with tqdm.trange(batches_per_epoch, unit="batch", mininterval=0) as bar:
        bar.set_description(f"Training epoch {epoch}")
        model.train()
        for (images, labels) in train_dataloader:
            (images, labels) = (images.to(device), labels.to(device))
            out_images = model(images)
            loss = loss_fn(out_images, labels)
            acc = (torch.argmax(out_images, 1) == labels).float().mean()
            opt.zero_grad()
            loss.backward()
            opt.step()
            bar.set_postfix(
                loss=float(loss),
                acc=float(acc)
            )
        model.eval()
        bar.set_description(f"Testing epoch {epoch}")
        for (images, labels) in test_dataloader:
            (images, labels) = (images.to(device), labels.to(device))
            out_images = model(images)
            val_loss = loss_fn(out_images, labels)
            acc = (torch.argmax(out_images, 1) == labels).float().mean()
            epoch_loss.append(float(val_loss))
            epoch_acc.append(float(acc))
            bar.set_postfix(
                val_loss=float(val_loss),
                acc=float(acc)
            )
    history['val_loss'].append(np.mean(epoch_loss))
    history['val_acc'].append(np.mean(epoch_acc))
    if np.mean(epoch_acc) > best_acc:
        best_acc = np.mean(epoch_acc)
        best_weights = copy.deepcopy(model.state_dict())
    print(f"Epoch {epoch} validation: Cross-entropy={np.mean(epoch_loss)}, Accuracy={np.mean(epoch_acc)}, LR={opt.param_groups[0]['lr']}")
    scheduler.step(np.mean(epoch_loss))

torch.save(model, 'saved%s.pth'%model_out_name)


[INFO] 270000 samples to train, and 90000 to validate over 10 epochs
[INFO] initializing the ResNet model...


Training epoch 0:   0%|          | 0/1800 [00:05<?, ?batch/s, acc=0.133, loss=1.2]  

In [None]:
# def get_prediction(x, model: pl.LightningModule):
#   model.freeze() # prepares model for predicting
#   probabilities = torch.softmax(model(x), dim=1)
#   predicted_class = torch.argmax(probabilities, dim=1)
#   return predicted_class, probabilities

# inference_model = ResNetDeepLense.load_from_checkpoint("resnet50_deeplense.pt", map_location="cuda").to(device)
# true_y, pred_y, prob_l = [], [], []
# for batch in tqdm(iter(test_dataloader), total=len(test_dataloader)):
#   x, y = batch
#   (x, y) = (x.to(device), y.to(device))
#   true_y.extend(y)
#   preds, probs = get_prediction(x, inference_model)
#   pred_y.extend(preds)
#   prob_l.extend(probs)

In [None]:
# from sklearn.metrics import classification_report
# pred_y = torch.Tensor(pred_y)
# true_y = torch.tensor(true_y)
# print(classification_report(true_y, pred_y, digits=3))