In [1]:
import mlflow
import torch
from torch.utils.data import DataLoader
from glob import glob
import os
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

from components.mymodel import load_model, get_model
from components.helper import train

In [2]:
# Run on the first CUDA GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")
# To force CPU execution, uncomment the override below.
# DEVICE = 'cpu'
print("Device: ", DEVICE)

Device:  cuda:0


In [3]:
import torch
from torch.utils.data import Dataset
from glob import glob
import os
import pandas as pd
from torchvision import io
from enum import Enum

class Devices(Enum):
    """Phone folder names under each image-set folder of the bigset dataset.

    The ``.value`` of a member is joined directly into the dataset path
    (``<dataset_path>/<imageset>/<device>/<environment>``).
    """
    iphone = 'Iphone'
    oppo = 'Oppo'
    redmi = 'Redmi'
    samsung = 'Samsung'
    # Glob wildcard: when joined into the path it matches every phone folder.
    all = '*'

class Environments(Enum):
    """Environment sub-folder names (below each phone folder) of the bigset dataset."""
    indoor = 'Indoor'
    outdoor = 'Outdoor'
    # Glob wildcard: when joined into the path it matches every environment folder.
    all = '*'

class Imageset(Enum):
    """Top-level image-set folders of the bigset dataset.

    Each folder contains the per-phone image tree and its own ``meta.csv`` of targets.
    """
    om = 'OM'
    p = 'P'

class SoilDataset_bigset(Dataset):
    """Soil images from the "bigset" dataset, paired with their target rows from ``meta.csv``.

    Expected on-disk layout::

        <dataset_path>/<imageset>/meta.csv
        <dataset_path>/<imageset>/<device>/<environment>/<sample id>/<image file>

    ``meta.csv`` must have an ``id`` column. The integer name of the folder that
    directly contains an image is looked up in that column to get its target.
    """

    def __init__(self, imageset: "Imageset", device: "Devices", environment: "Environments",
                 transform=None, dataset_path: str = './dataset/bigset/'):
        """Index every image matching the requested image set / device / environment.

        Args:
            imageset: Image-set folder (and its ``meta.csv``) to use.
            device: Phone folder to read; a ``'*'`` value globs over all phones.
            environment: Environment folder to read; a ``'*'`` value globs over all.
            transform: Optional callable applied to each image tensor in ``__getitem__``.
            dataset_path: Root directory of the dataset. Defaults to the previously
                hard-coded location, so existing callers are unaffected.

        Raises:
            AssertionError: if ``dataset_path`` does not exist.
        """
        assert os.path.exists(dataset_path), f"{dataset_path=} is not exist."
        image_folder = os.path.join(dataset_path, imageset.value, device.value, environment.value)
        # glob returns entries in arbitrary directory order; sort so index -> image is reproducible.
        self.imgs = sorted(glob(os.path.join(image_folder, '*', '*')))
        print(f"Found {len(self.imgs)} images in {image_folder}.")

        # Lookup table for targets, indexed by sample id.
        target_path: str = os.path.join(dataset_path, imageset.value, 'meta.csv')
        self.target_df = pd.read_csv(target_path, index_col='id')

        # NOTE(review): with the '*' wildcard members this string contains literal '*'
        # components; callers that build directories from it create folders named '*'.
        self.signature = os.path.join(imageset.value, device.value, environment.value)
        self.transform = transform

    def get_target(self, img_path: str) -> "pd.Series":
        """Return the ``meta.csv`` row for the sample that ``img_path`` belongs to.

        Raises:
            ValueError: if the image's parent folder name is not an integer id.
            KeyError: if that id is missing from ``meta.csv``.
        """
        # The sample id is the name of the image's parent folder. Reading it from the
        # path structure, not from a fixed index into img_path.split('/'), keeps this
        # correct for any dataset_path depth and for OS-specific path separators.
        target_id = int(os.path.basename(os.path.dirname(img_path)))
        return self.target_df.loc[target_id]  # type:ignore

    def __len__(self):
        """Number of images found at construction time."""
        return len(self.imgs)

    def __getitem__(self, idx):
        """Return ``(image, target, img_path)`` for sample ``idx``.

        ``image`` is the (optionally transformed) image as a float tensor, and
        ``target`` is the image's ``meta.csv`` row as a float32 tensor.
        """
        img_path = self.imgs[idx]
        row = self.get_target(img_path=img_path)
        # Convert in one step via numpy instead of letting torch iterate the Series.
        y = torch.tensor(row.to_numpy(dtype='float32'))
        # Force 3-channel decoding: grayscale or RGBA files would otherwise yield 1- or
        # 4-channel tensors, which break 3-channel normalisation and batch collation.
        X = io.read_image(img_path, mode=io.ImageReadMode.RGB)
        if self.transform:
            X = self.transform(X)
        return X.float(), y, img_path

In [4]:
def train_model(model, dataset:SoilDataset_bigset, epochs:int, lr:float, save_path:str,
                batch_size:int = 200, num_workers:int = 30):
    """Train ``model`` on ``dataset`` and plot the returned training-loss curve.

    Args:
        model: Network to train; passed straight to ``train``.
        dataset: Dataset to draw shuffled mini-batches from; its ``signature`` titles the plot.
        epochs: Number of epochs, forwarded to ``train``.
        lr: Learning rate, forwarded to ``train``.
        save_path: Checkpoint path, forwarded to ``train``.
        batch_size: DataLoader mini-batch size (previously hard-coded to 200).
        num_workers: DataLoader worker processes (previously hard-coded to 30).

    Returns:
        ``(model, train_losses)`` exactly as returned by ``train``.
    """
    loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    model, train_losses = train(model, loader, epochs, lr, DEVICE, save_path)
    fig, ax = plt.subplots()
    ax.plot(train_losses)
    # NOTE(review): assumes train() returns one loss value per epoch — confirm.
    ax.set(title=dataset.signature, xlabel='epoch', ylabel='training loss')
    plt.show()
    return model, train_losses

In [5]:
# Tracking server: honour MLFLOW_TRACKING_URI when it is set, otherwise use the default below.
mlflow.set_tracking_uri(os.environ.get("MLFLOW_TRACKING_URI", "https://web-mlflow.akraradets.duckdns.org"))
mlflow.set_experiment(experiment_name='Soil')
# The run stays active for the following cells.
# NOTE(review): close it with mlflow.end_run() after training; no such call is visible here.
mlflow.start_run(run_name="OM Bigset-resizefirst")

<ActiveRun: >

In [6]:
# from PIL import Image

# img = Image.open(dataset.imgs[100])
# # dataset.imgs[100]
# img
# plt.imshow( img_array )

In [7]:
# Slice of the dataset to train on. These values drive BOTH the dataset built below
# and the parameters logged to MLflow, so the two cannot drift apart.
image_set = Imageset.om
device = Devices.all
environment = Environments.all

train_losses = []
# Candidate architectures: mobilenet_v3_large, resnet50, efficientnet_v2_l, alexnet
model_name = 'alexnet'
# Log explicitly instead of relying on mlflow.log_param's return value.
mlflow.log_param("model_name", model_name)
weight_path = f'./weight/{model_name}'

params: dict = {
    'epochs': 1000,
    'lr': 0.001,
    'Imageset': image_set.value,
    'Device': device.value,
    'Environment': environment.value,
}

mlflow.log_params(params)

preprocess = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(350),
    transforms.CenterCrop(224),
    # transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# Use the configured slice (previously hard-coded here, independent of the logged params).
dataset = SoilDataset_bigset(imageset=image_set, device=device, environment=environment, transform=preprocess)

# NOTE(review): with the '*' wildcard values, dataset.signature contains literal '*'
# path components, so checkpoints are written under directories named '*'.
weight_path = os.path.join(weight_path, dataset.signature)
os.makedirs(weight_path, exist_ok=True)

save_path = os.path.join(weight_path, f"{model_name}.pth")
# Continue
# model = get_model(model_name=model_name, image_set=image_set)
# From scratch
model = load_model(model_name=model_name)
# model = torch.compile(model)
model, train_loss = train_model(model, dataset=dataset, save_path=save_path, epochs=params['epochs'], lr=params['lr'])
train_losses.append(train_loss)

Found 4047 images in ./dataset/bigset/OM/*/*.
44.703086376190186 0 tensor(730.9408)




44.64417862892151 1 tensor(131.3284)




45.230419397354126 2 tensor(87.2748)




45.82435703277588 3 tensor(61.8670)




45.01281547546387 4 tensor(50.4504)




44.35214400291443 5 tensor(44.1620)




44.85190176963806 6 tensor(43.4198)




44.43047308921814 7 tensor(40.0337)




44.62518763542175 8 tensor(40.2046)




44.608874559402466 9 tensor(38.4944)




45.20341157913208 10 tensor(33.8612)




45.00702166557312 11 tensor(31.9138)




44.438230991363525 12 tensor(30.5974)




45.2685751914978 13 tensor(29.6633)




45.29035711288452 14 tensor(29.0159)




45.05148530006409 15 tensor(29.4430)




45.58117747306824 16 tensor(25.5613)




45.15455603599548 17 tensor(25.8697)




44.80406737327576 18 tensor(25.0991)




44.36273980140686 19 tensor(23.8613)




45.49970769882202 20 tensor(25.4825)




44.976202964782715 21 tensor(23.2265)




44.63121175765991 22 tensor(22.7698)




44.74037480354309 23 tensor(20.4901)




44.202494382858276 24 tensor(21.2090)




44.24297118186951 25 tensor(18.3262)




44.72164964675903 26 tensor(17.8153)




45.64935350418091 27 tensor(16.7187)




44.829657316207886 28 tensor(15.7377)




44.51461434364319 29 tensor(17.1793)




44.66601037979126 30 tensor(16.8654)




44.54109525680542 31 tensor(15.1744)




44.990129232406616 32 tensor(16.3020)




44.72747015953064 33 tensor(15.3872)




44.18075084686279 34 tensor(12.6329)




45.1493935585022 35 tensor(11.8819)




44.956571102142334 36 tensor(11.4805)




44.51082944869995 37 tensor(11.8079)




44.404218673706055 38 tensor(11.5652)




44.59766936302185 39 tensor(10.4548)




44.71978998184204 40 tensor(10.2042)




44.95344138145447 41 tensor(10.3805)




45.35662865638733 42 tensor(9.6906)




44.89081597328186 43 tensor(10.0596)




45.224894523620605 44 tensor(8.2805)




44.96916151046753 45 tensor(8.5110)




44.97727179527283 46 tensor(7.6524)




44.73147988319397 47 tensor(8.9211)




44.87804102897644 48 tensor(8.6815)




44.67917060852051 49 tensor(7.5303)




44.988996267318726 50 tensor(7.4900)




44.44933724403381 51 tensor(7.9968)




44.92501449584961 52 tensor(7.0523)




44.4626407623291 53 tensor(6.7528)




44.52277421951294 54 tensor(6.2897)




44.641605854034424 55 tensor(5.2432)




45.39827537536621 56 tensor(5.3695)




44.759543895721436 57 tensor(5.8026)




44.65455484390259 58 tensor(5.1632)




45.34510898590088 59 tensor(4.9543)




45.11558747291565 60 tensor(4.4605)




44.714481830596924 61 tensor(4.4951)




44.82256245613098 62 tensor(4.7713)




44.2913715839386 63 tensor(4.5864)




44.55144953727722 64 tensor(4.0122)




45.172199726104736 65 tensor(4.9353)




44.41841769218445 66 tensor(4.2255)




44.997318506240845 67 tensor(3.5974)




44.53584003448486 68 tensor(3.4636)




45.02940893173218 69 tensor(3.7437)




45.40541934967041 70 tensor(3.6255)




44.9139130115509 71 tensor(3.4075)




44.45180320739746 72 tensor(3.0729)




45.386693716049194 73 tensor(3.2806)




44.74017810821533 74 tensor(3.5114)




45.0582013130188 75 tensor(2.9591)




44.26872944831848 76 tensor(2.9165)




45.839937925338745 77 tensor(3.1168)




44.192814111709595 78 tensor(3.6508)




45.07268929481506 79 tensor(3.0911)




44.75404214859009 80 tensor(2.7752)




45.138657093048096 81 tensor(2.6480)




44.687793254852295 82 tensor(3.0292)




44.67903470993042 83 tensor(3.1454)




44.26197791099548 84 tensor(5.0310)




45.10988402366638 85 tensor(3.4471)




44.849488973617554 86 tensor(3.4225)




44.83774542808533 87 tensor(3.4208)




45.04637336730957 88 tensor(3.4118)




45.98921585083008 89 tensor(2.6796)




45.28853940963745 90 tensor(2.2391)




44.68880867958069 91 tensor(2.3176)




45.85479712486267 92 tensor(2.6013)




44.649394035339355 93 tensor(2.2154)




44.926584243774414 94 tensor(2.8065)




44.93347430229187 95 tensor(2.6008)




45.150370597839355 96 tensor(3.2893)




44.50854229927063 97 tensor(2.5121)




45.30114483833313 98 tensor(2.3079)




45.09585642814636 99 tensor(2.0958)




45.280834913253784 100 tensor(2.0349)




45.00998592376709 101 tensor(1.7972)




44.72337889671326 102 tensor(1.8631)




45.03383708000183 103 tensor(1.7376)




44.471091508865356 104 tensor(1.7341)




45.10652422904968 105 tensor(1.8620)




44.30458068847656 106 tensor(1.6961)




44.58238697052002 107 tensor(1.5513)




44.92825222015381 108 tensor(1.4569)




44.96873736381531 109 tensor(1.6723)




44.56617498397827 110 tensor(1.7697)




44.68644332885742 111 tensor(1.8120)




45.0668580532074 112 tensor(1.7459)




45.725762605667114 113 tensor(1.5468)




44.37791633605957 114 tensor(1.9209)




44.69333267211914 115 tensor(1.7732)




44.27245283126831 116 tensor(1.6162)




44.06512403488159 117 tensor(1.5573)




44.86022996902466 118 tensor(1.4193)




45.492722034454346 119 tensor(1.8371)




44.71197986602783 120 tensor(1.5105)




44.23534822463989 121 tensor(1.3943)




45.64758539199829 122 tensor(1.3104)




44.80809736251831 123 tensor(1.3637)




44.814422369003296 124 tensor(1.2614)




45.743921756744385 125 tensor(1.3539)




44.755951166152954 126 tensor(1.5107)




44.453399419784546 127 tensor(2.0273)




45.222617864608765 128 tensor(2.4234)




44.54410195350647 129 tensor(1.9094)




44.553548097610474 130 tensor(1.5287)




44.19206500053406 131 tensor(1.4918)




44.45829653739929 132 tensor(1.2617)




44.39392328262329 133 tensor(1.4758)




44.81990194320679 134 tensor(1.5166)




45.69951844215393 135 tensor(2.0598)




44.456400632858276 136 tensor(2.0495)




44.63743019104004 137 tensor(1.5589)




44.80612301826477 138 tensor(1.4701)




44.79983973503113 139 tensor(1.4850)




44.34563970565796 140 tensor(1.3060)




45.42669439315796 141 tensor(1.6828)




45.18252921104431 142 tensor(1.8392)




44.911064863204956 143 tensor(2.1811)




44.21333384513855 144 tensor(1.5611)




45.56359052658081 145 tensor(1.2328)




46.00407361984253 146 tensor(1.1792)




44.42733144760132 147 tensor(1.2811)




44.6161687374115 148 tensor(1.3315)




44.61905026435852 149 tensor(1.1523)




44.24516201019287 150 tensor(1.1848)




44.56145167350769 151 tensor(1.0741)




44.700767278671265 152 tensor(1.0398)




44.903467655181885 153 tensor(1.1415)




45.11295461654663 154 tensor(1.3547)




44.772443771362305 155 tensor(1.5333)




44.36402893066406 156 tensor(1.2330)




44.574585914611816 157 tensor(1.1519)




44.57452130317688 158 tensor(1.2806)




44.45260572433472 159 tensor(1.3840)




44.1210355758667 160 tensor(2.3002)




45.3221492767334 161 tensor(2.4608)




44.73396635055542 162 tensor(2.5682)




45.086456537246704 163 tensor(2.3972)




45.236016511917114 164 tensor(2.7078)




44.762155294418335 165 tensor(14.5322)




44.805408000946045 166 tensor(9.3705)




45.21682262420654 167 tensor(6.0057)




45.267683267593384 168 tensor(4.1851)




44.76121497154236 169 tensor(3.5580)




45.02547907829285 170 tensor(3.0481)




44.57739734649658 171 tensor(2.8891)




44.86013221740723 172 tensor(2.4907)




44.338966369628906 173 tensor(1.9255)




45.21996188163757 174 tensor(1.5725)




44.35354566574097 175 tensor(1.4156)




44.92816877365112 176 tensor(1.5552)




44.9728422164917 177 tensor(1.3812)




45.19112491607666 178 tensor(1.1634)




44.36181998252869 179 tensor(1.3535)




44.57973623275757 180 tensor(1.4205)




44.63870429992676 181 tensor(1.1732)




44.85123300552368 182 tensor(1.3543)




45.06493067741394 183 tensor(1.1779)




45.52864480018616 184 tensor(1.0432)




45.04552674293518 185 tensor(1.0683)




45.03289771080017 186 tensor(1.0115)




44.87040090560913 187 tensor(0.9705)




44.9497709274292 188 tensor(0.9615)




44.867520332336426 189 tensor(1.1919)




45.238638401031494 190 tensor(1.1092)




44.30668044090271 191 tensor(1.0473)




44.524500131607056 192 tensor(0.9487)




44.7880756855011 193 tensor(0.9676)




45.06849408149719 194 tensor(0.9484)




44.595155000686646 195 tensor(0.9173)




44.63373064994812 196 tensor(1.0004)




44.55176043510437 197 tensor(1.1133)




44.664244174957275 198 tensor(1.0910)




44.466254472732544 199 tensor(1.2100)




45.73988914489746 200 tensor(1.1541)




44.9560067653656 201 tensor(1.0911)




45.07030963897705 202 tensor(1.1072)




44.67632079124451 203 tensor(1.3144)




44.83812737464905 204 tensor(1.2400)




45.543158531188965 205 tensor(1.1325)




45.31908845901489 206 tensor(1.0811)




44.82167959213257 207 tensor(1.2849)




44.63957762718201 208 tensor(1.3417)




44.28152394294739 209 tensor(1.1591)




44.59199285507202 210 tensor(0.9856)




44.701757192611694 211 tensor(0.9117)




44.88970232009888 212 tensor(1.0543)


