In [None]:
import os, warnings
from pathlib import Path
from glob import glob

import numpy as np 
import scipy as sp
import matplotlib as mpl
import matplotlib.pyplot as plt

import torch
import skimage as ski
import sklearn as skl

from plantcv import plantcv as pcv
import flyr

In [None]:
# writing a class
class FlirDataset(torch.utils.data.Dataset):
    """Dataset of square optical tiles cut from annotated FLIR images.

    For every ``*.png`` under ``<path>/training/annotated``, the FLIR file with
    the same stem found under ``<path>/*/thermal/*.jpg`` is loaded with
    :meth:`load_flir`. The thermal, optical, and annotation images are then cut
    into non-overlapping ``tile_size`` x ``tile_size`` tiles; only tiles that
    lie completely inside all three images are kept.

    Each sample is a tuple ``(optical, frac)``:

    - ``optical`` is a float tensor of shape ``(channels, tile_size, tile_size)``
      holding the optical tile divided by 255 (presumably 3 RGB channels).
    - ``frac`` is a 0-d float tensor equal to 1 minus the fraction of tile
      pixels whose annotation is exactly ``[1, 0, 0, 1]``.

    Parameters
    ----------
    path : pathlike
        Root folder of the data set.
    tile_size : int, optional
        Edge length of the square tiles in pixels; the default is 45.

    Attributes
    ----------
    names : list of str
        File stems of the annotated images.
    images : dict
        ``name -> (thermal_image, optical_image)`` as returned by `load_flir`.
    annots : dict
        ``name -> annotation image`` as read by ``plt.imread``.
    sample_data : dict
        ``(name, row#, col#) -> (rowidx, colidx, opt_sub, ann_sub, thr_sub)``.
    masks : dict
        ``(name, row#, col#) -> boolean mask`` of the ``[1, 0, 0, 1]`` pixels.
    samples : list
        The ``(optical, frac)`` tuples, in the same order as `sample_data`.
    """
    def __len__(self):
        return len(self.samples)
    def __getitem__(self, k):
        return self.samples[k]
    def __init__(self, path, tile_size=45):
        path = Path(path)
        self.tile_size = tile_size
        # Annotation images, keyed by file stem:
        annot_pattern = str(path / "training" / "annotated" / "*.png")
        annot_ims = {
            Path(filename).stem: plt.imread(filename)
            for filename in glob(annot_pattern)}
        # FLIR files from every sub-folder's 'thermal' directory, keyed the
        # same way so that they can be matched to the annotations:
        flir_pattern = str(path / "*" / "thermal" / "*.jpg")
        flir_filenames = {
            Path(file).stem: file
            for file in glob(flir_pattern)}
        self.names = list(annot_ims.keys())
        # Raises KeyError if an annotation has no matching FLIR file.
        self.images = {
            key: self.load_flir(flir_filenames[key])
            for key in self.names}
        self.annots = annot_ims
        # Cut every (thermal, optical, annotation) triple into tiles keyed by
        # (name, row number, column number):
        sdata = {}
        for key in self.names:
            (thr_im, opt_im) = self.images[key]
            tiles = self._extract_tiles(
                thr_im, opt_im, self.annots[key], tile_size)
            for ((rno, cno), tup) in tiles.items():
                sdata[key, rno, cno] = tup
        self.sample_data = sdata
        # Convert each tile into a (tensor, target) training sample:
        self.masks = {}
        self.samples = []
        for (skey, (_, _, opt_sub, ann_sub, _)) in sdata.items():
            (opt_for_torch, ann_frac, plant_pixels) = self._make_sample(
                opt_sub, ann_sub)
            self.masks[skey] = plant_pixels
            self.samples.append((opt_for_torch, ann_frac))
    @staticmethod
    def _extract_tiles(thr_im, opt_im, ann_im, tile_size=45):
        """Cut aligned, non-overlapping square tiles out of three images.

        Only tiles that fit completely inside all three images are returned, so
        every sub-image has exactly ``tile_size`` rows and columns.

        Returns
        -------
        dict
            ``(row#, col#) -> (rowidx, colidx, opt_sub, ann_sub, thr_sub)``,
            inserted in row-major order.
        """
        # Bound the tiling by the common extent; the optical crop produced by
        # load_flir can be smaller than the thermal image, and the annotation
        # is not guaranteed to match either.
        nrows = min(opt_im.shape[0], ann_im.shape[0], thr_im.shape[0])
        ncols = min(opt_im.shape[1], ann_im.shape[1], thr_im.shape[1])
        tiles = {}
        # `n - tile_size + 1` as the stop keeps a final tile ending exactly at n.
        for (rno, rowidx) in enumerate(range(0, nrows - tile_size + 1, tile_size)):
            rows = slice(rowidx, rowidx + tile_size)
            for (cno, colidx) in enumerate(range(0, ncols - tile_size + 1, tile_size)):
                cols = slice(colidx, colidx + tile_size)
                tiles[rno, cno] = (
                    rowidx, colidx,
                    opt_im[rows, cols],
                    ann_im[rows, cols],
                    thr_im[rows, cols])
        return tiles
    @staticmethod
    def _make_sample(opt_sub, ann_sub):
        """Build ``(optical_tensor, frac, mask)`` for one tile.

        ``optical_tensor`` is ``opt_sub`` moved to channels-first order and
        divided by 255. ``mask`` marks the pixels whose annotation equals
        ``[1, 0, 0, 1]``, and ``frac`` is 1 minus the fraction of such pixels.
        """
        # Assumes ann_sub is a float RGBA image from plt.imread, so that
        # [1, 0, 0, 1] is opaque red. TODO confirm.
        plant_pixels = np.all(ann_sub == [1, 0, 0, 1], axis=2)
        opt_for_torch = torch.permute(
            torch.tensor(opt_sub, dtype=torch.float) / 255,
            (2, 0, 1))
        # NOTE(review): this is 1 - (fraction of `plant_pixels`), but later
        # plots label it "True Plant Fraction". Confirm whether the red
        # annotation marks plants or background.
        ann_frac = 1 - np.sum(plant_pixels) / plant_pixels.size
        ann_frac = torch.tensor(ann_frac, dtype=torch.float)
        return (opt_for_torch, ann_frac, plant_pixels)
    def load_flir(self, filename, thermal_unit='celsius'):
        """Loads a FLIR image file and returns its thermal image together with
        the portion of its optical image that overlaps the thermal data.

        Parameters
        ----------
        filename : pathlike
            A ``pathlib.Path`` object or a string representing the filename of
            the image that is to be loaded.
        thermal_unit : {'celsius' | 'kelvin' | 'fahrenheit'}, optional
            What temperature units to return; the default is ``'celsius'``.
            The value is read as an attribute of the unpacked ``flyr`` image.

        Returns
        -------
        thermal_image : numpy.ndarray
            An image-array with shape ``(rows, cols)`` containing the thermal
            values in ``thermal_unit``, resized to the optical resolution.
        optical_image : numpy.ndarray
            An image-array with shape ``(rows, cols, 3)`` containing the RGB
            optical image, cropped to the thermal field of view. It can be
            smaller than ``thermal_image`` if the crop reaches the edge of the
            optical image.
        """
        from os import fspath
        from PIL import Image
        import flyr
        # Make sure we have a path:
        filename = fspath(filename)
        # Import the raw image data:
        flir_image = flyr.unpack(filename)
        # Extract the optical and thermal data:
        opt = flir_image.optical
        thr = getattr(flir_image, thermal_unit)
        # Picture-in-picture metadata: offset and optical/thermal scale.
        pip = flir_image.pip_info
        x0 = pip.offset_x
        y0 = pip.offset_y
        ratio = pip.real_to_ir
        ratio = opt.shape[0] / thr.shape[0] / ratio
        # Resize the thermal image to match the optical image in resolution:
        (opt_rs, opt_cs, _) = opt.shape
        (thr_rs, thr_cs) = np.round(np.array(thr.shape) * ratio).astype(int)
        thr = np.array(Image.fromarray(thr).resize([thr_cs, thr_rs]))
        # Top-left corner of the thermal view inside the optical image:
        x0 = round(opt_cs // 2 - thr_cs // 2 + x0)
        y0 = round(opt_rs // 2 - thr_rs // 2 + y0)
        return (thr, opt[y0:y0+thr_rs, x0:x0+thr_cs, :])

In [None]:
# Root folder of the FLIR data set. Set the FLIR_DATA_DIR environment variable
# to point elsewhere instead of editing this cell.
DATA_DIR = Path(os.environ.get(
    'FLIR_DATA_DIR', Path.home() / 'Desktop' / 'monthly images'))
train_dset = FlirDataset(DATA_DIR)

In [None]:
class FlirResNet(torch.nn.Module):
    """A torch.hub vision backbone followed by a linear layer giving one value
    per image.

    Parameters
    ----------
    resnet : str, optional
        Entry-point name in the ``pytorch/vision:v0.13.0`` hub repository; the
        default is ``'resnet18'``.
    weights : str, optional
        Weights identifier forwarded to that entry point; the default is
        ``'IMAGENET1K_V1'``.
    """
    def __init__(self, resnet='resnet18', weights='IMAGENET1K_V1'):
        super().__init__()
        hub_repo = 'pytorch/vision:v0.13.0'
        self.resnet = torch.hub.load(hub_repo, resnet, weights=weights)
        # The head expects the backbone to emit 1000 features per image.
        self.linear = torch.nn.Linear(1000, 1)
    def forward(self, inputs):
        """Return a 1-D tensor with one raw (no activation) output per input."""
        features = self.resnet(inputs)
        head_out = self.linear(features)
        return head_out.flatten()

In [None]:
import torch
import torchvision

# Hyperparameters:
n_epochs = 8  # 1 epoch == show all training data to the model once.
lr = 0.0005   # We use a fairly low learning rate.
# How many images in one training batch. Using the whole dataset makes every
# epoch a single full-batch gradient step.
batch_size = len(train_dset)
seed = 0      # Seeds torch's RNG (new-layer initialization, shuffle order).

torch.manual_seed(seed)

# Make the model:
model = FlirResNet()

# Make the optimizer and LR-manager:
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
steplr = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=1,
    gamma=0.65)

# Declare our loss function:
loss_fn = torch.nn.L1Loss()

# Make the dataloaders:
train_dloader = torch.utils.data.DataLoader(train_dset, batch_size=batch_size, shuffle=True)
#test_dloader = torch.utils.data.DataLoader(test_dset, batch_size=batch_size, shuffle=True)

# Now we start the optimization loop:
for epoch_num in range(n_epochs):
    # Put the model in train mode:
    model.train()
    # In each epoch, we go through each training sample once; the dataloader
    # gives these to us in batches:
    total_train_loss = 0.0
    for (inputs, targets) in train_dloader:
        # We're starting a new step, so we reset the gradients.
        optimizer.zero_grad()
        # Calculate the model prediction for these inputs.
        preds = model(inputs)
        # The model emits raw outputs; the sigmoid maps them into [0, 1] to
        # match the fractional targets.
        train_loss = loss_fn(torch.sigmoid(preds), targets)
        # Have PyTorch backward-propagate the gradients.
        train_loss.backward()
        # Have the optimizer take a step:
        optimizer.step()
        # Accumulate a plain float. Summing the loss tensors themselves would
        # keep every batch's autograd graph alive until the epoch ends.
        total_train_loss += train_loss.item() * len(targets)
    # LR Scheduler step:
    steplr.step()
    mean_train_loss = total_train_loss / len(train_dset)
    # To evaluate on held-out data, put the model in evaluation mode.
    #model.eval()
    ## Evaluate the model using the test data.
    #total_test_loss = 0.0
    #with torch.no_grad():
    #    for (inputs, targets) in test_dloader:
    #        preds = model(inputs)
    #        test_loss = loss_fn(torch.sigmoid(preds), targets)
    #        total_test_loss += test_loss.item() * len(targets)
    #mean_test_loss = total_test_loss / len(test_dset)
    # Print something about this step:
    print(f"Epoch {epoch_num:2d} loss: {mean_train_loss:6.3f}")
# After the optimizer has run, print out what it's found:
print("Final result:")
print(f"  train loss = ", float(mean_train_loss))
#print(f"   test loss = ", float(mean_test_loss))
#print(f"   test loss = ", float(mean_test_loss))

In [None]:
ds = train_dset

# Collect every sample's target fraction (x) and optical tile:
x = []
ims = []
for (im, f) in ds:
    x.append(f)
    ims.append(im)
ims = torch.stack(ims, 0)

# Predict in inference mode. eval() makes BatchNorm use its running statistics
# instead of this batch's and stops it updating them. no_grad() avoids building
# an autograd graph over the whole dataset.
# NOTE(review): if predictions look worse than during training, the running
# statistics may not have adapted after only a few full-batch steps.
model.eval()
with torch.no_grad():
    y = torch.sigmoid(model(ims).flatten())

(x, y) = (np.array(x), y.numpy())

In [None]:
# Mean absolute error between true (x) and predicted (y) fractions.
np.abs(x - y).mean()

In [None]:
# Correlation matrix of the true (x) and predicted (y) fractions; the
# off-diagonal entries are their Pearson correlation coefficient.
np.corrcoef(x, y)

In [None]:
# Predicted vs. true fractions, in percent, one dot per sample. The dotted red
# diagonal marks perfect agreement.
(fig, ax) = plt.subplots(1, 1, figsize=(5, 4), dpi=288)

ax.scatter(100 * x, 100 * y, c='k', s=0.5, alpha=0.5)
ax.plot([0, 100], [0, 100], 'r:', zorder=-10)
ax.set(
    xlim=[0, 100],
    ylim=[0, 100],
    xlabel='True Plant Fraction [%]',
    ylabel='Predicted Plant Fraction [%]')

plt.show()

In [None]:
# Index of the sample with the largest absolute prediction error.
# NOTE(review): the original referenced an undefined `u`; the per-sample
# absolute error is assumed to be what was meant.
np.argmax(np.abs(x - y))

In [None]:
# Look at one dataset sample: print its target fraction, then show its optical
# tile in (rows, cols, channels) order for imshow.
(im, f) = ds[134]
print(f)
plt.imshow(im.permute(1, 2, 0))

In [None]:
# NOTE(review): `filenames` was not defined anywhere in this notebook. It is
# rebuilt here from the FLIR thermal files that FlirDataset reads; confirm
# that this list and its (sorted) order are what index 11 was meant to hit.
filenames = sorted(glob(str(
    Path.home() / 'Desktop' / 'monthly images' / '*' / 'thermal' / '*.jpg')))
filename = filenames[11]

# Threshold applied to the Lab 'b' channel in step (3).
b_threshold = 130

# (1) Read in the image and rotate it 90 degrees counter-clockwise
#     (the same as transposing the rows/columns and flipping up-down):
dat = flyr.unpack(filename)
im0 = np.rot90(dat.optical)
th0 = np.rot90(dat.celsius)

# (2) Extract the yellow-blue ('b') channel of the Lab color space:
im_b = pcv.rgb2gray_lab(rgb_img=im0, channel='b')

# (3) Pick a threshold ('light' objects: pixels above the threshold):
im_mask = pcv.threshold.binary(
    gray_img=im_b,
    threshold=b_threshold,
    object_type='light')

# (4) White out the masked pixels of the original image.
# NOTE(review): the original comment said the *out-of-mask* pieces are
# deleted, but `im_mask > 0` selects the in-mask pixels; use `im_mask == 0`
# if the out-of-mask region is what should be removed.
im_seg = np.array(im0)
im_seg[im_mask > 0, :] = 255

(fig, axs) = plt.subplots(2, 2, figsize=(7,7), dpi=288)
axs = axs.flatten()

axs[0].imshow(im0)
axs[1].imshow(im_b, cmap='gray', vmin=0, vmax=255)
axs[2].imshow(im_mask, cmap='gray', vmin=0, vmax=255)
axs[3].imshow(im_seg, cmap='gray', vmin=0, vmax=255)

for (ax, title) in zip(axs, ('optical', "Lab 'b' channel", 'mask', 'segmented')):
    ax.set_title(title)
    ax.axis('off')

In [None]:
# Show which FLIR file the segmentation cell used.
filename

In [None]:
# Show the separately exported RGB image for the same scene. The path is built
# from the home directory rather than hard-coding one user's absolute path.
plt.imshow(
    mpl.image.imread(str(
        Path.home() / 'Desktop' / 'monthly images' / '6_25' / 'RGB' / 'FLIR2369.jpg')))

In [None]:
# Display the rotated Celsius thermal image.
plt.imshow(th0)

In [None]:
# Display the rotated optical image.
plt.imshow(im0)

In [None]:
# Histogram-equalize each color channel of the optical image independently and
# show the recombined result.
plt.imshow(np.dstack([pcv.hist_equalization(im0[..., channel]) for channel in range(3)]))

In [None]:
# Look up the PlantCV binary-threshold API. Plain help() also works when the
# notebook is exported to a script, unlike the IPython-only `?` syntax.
help(pcv.threshold.binary)

In [None]:
# Load the image explicitly. Above this cell the only `im` is the
# (3, 45, 45) dataset-sample tensor, which neither imshow nor the Lab
# conversion below can use.
im = mpl.image.imread(str(Path.home() / 'Desktop' / 'FLIR3099.png'))
plt.imshow(im)

In [None]:
# Extract the yellow-blue ('b') channel of the Lab color space from `im`.
b_img = pcv.rgb2gray_lab(rgb_img=im, channel='b')

In [None]:
# Display the b channel on a fixed 0-255 scale (default colormap).
plt.imshow(b_img, vmin=0, vmax=255)

In [None]:
# Binary mask of the 'light' pixels of the b channel, using a threshold of 120.
thresh_mask = pcv.threshold.binary(
    gray_img=b_img,
    object_type='light',
    threshold=120,
)

In [None]:
# Display the threshold mask in grayscale.
plt.imshow(thresh_mask, cmap='gray')

In [None]:
# Copy the image, set every pixel selected by the threshold mask to 0 in all
# channels, and show the result.
immask = np.array(im)
immask[thresh_mask > 0, :] = 0
plt.imshow(immask)

In [None]:
# Histogram of the b-channel values; with hist_data=True the call returns both
# the figure and the underlying histogram data.
hist_figure1, hist_data1 = pcv.visualize.histogram(img = b_img, hist_data=True)

In [None]:
# Display the histogram figure.
hist_figure1

In [None]:
# Current working directory, via portable Python instead of the `pwd` shell
# command.
os.getcwd()

In [None]:
# Load the PNG from the user's Desktop without hard-coding an absolute path.
im = mpl.image.imread(str(Path.home() / 'Desktop' / 'FLIR3099.png'))

In [None]:
# Display the loaded PNG image.
plt.imshow(im)