Skip to content

Commit

Permalink
update the vision
Browse files Browse the repository at this point in the history
  • Loading branch information
YangYuqi317 authored and YangYuqi317 committed Jan 14, 2023
1 parent 6bd0ab0 commit 69c4995
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 2 deletions.
74 changes: 73 additions & 1 deletion docs/en/useful_tools.md
Expand Up @@ -183,4 +183,76 @@ def update_model(model: nn.Module, optimizer: Optimizer, pbar:Iterable, strategy

## Visualization

We designed this module to visualize the results.
We designed this module to visualize the results. This module can help to show the heat map, which is better for the result. In this module, `fiftyone` is mainly imported and we create a class named `VOXEL`.
```python
class VOXEL:

def __init__(self, dataset, name:str, persistent:bool=False, cuda:bool=True, interpreter:Interpreter=None) -> None:
self.dataset = dataset
self.name = name
self.persistent = persistent
self.cuda = cuda
self.interpreter = interpreter

if self.name not in self.loaded_datasets():
self.fo_dataset = self.create_dataset()
self.load()
else:
self.fo_dataset = fo.load_dataset(self.name)

self.view = self.fo_dataset.view()

def create_dataset(self) -> fo.Dataset:
return fo.Dataset(self.name)

def loaded_datasets(self) -> t.List:
return fo.list_datasets()

def load(self):

samples = []

for i in tqdm(range(len(self.dataset))):
path, anno = self.dataset.get_imgpath_anno_pair(i)

sample = fo.Sample(filepath=path)

# Store classification in a field name of your choice
sample["ground_truth"] = fo.Classification(label=anno)

samples.append(sample)

# Create dataset

self.fo_dataset.add_samples(samples)
self.fo_dataset.persistent = self.persistent

def predict(self, model:nn.Module, transforms, n:int=inf, name="prediction", seed=51, explain:bool=False):
model.eval()
if n < inf:
self.view = self.fo_dataset.take(n, seed=seed)

with fo.ProgressBar() as pb:
for sample in pb(self.view):
image = Image.open(sample.filepath)
image = transforms(image).unsqueeze(0)

if self.cuda:
image = image.cuda()
pred = model(image)
index = torch.argmax(pred).item()
confidence = pred[:, index].item()


sample[name] = fo.Classification(
label=str(index),
confidence=confidence
)

if self.interpreter:
heatmap = self.interpreter(image_path=sample.filepath, image_tensor=image, transforms=transforms)
sample["heatmap"] = fo.Heatmap(map=heatmap)

sample.save()
print("Finished adding predictions")
```
74 changes: 73 additions & 1 deletion docs/zh_CN/useful_tools.md
Expand Up @@ -186,4 +186,76 @@ def update_model(model: nn.Module, optimizer: Optimizer, pbar:Iterable, strategy

## 可视化

我们设计该模块将结果进行可视化
我们设计该模块将结果进行可视化,这个模块可以帮助显示热图,帮助我们更好的理解实验结果。在这个模块中,我们导入了'fiftyone',并且我们创建了一个名为'VOXEL'的类。
```python
class VOXEL:

def __init__(self, dataset, name:str, persistent:bool=False, cuda:bool=True, interpreter:Interpreter=None) -> None:
self.dataset = dataset
self.name = name
self.persistent = persistent
self.cuda = cuda
self.interpreter = interpreter

if self.name not in self.loaded_datasets():
self.fo_dataset = self.create_dataset()
self.load()
else:
self.fo_dataset = fo.load_dataset(self.name)

self.view = self.fo_dataset.view()

def create_dataset(self) -> fo.Dataset:
return fo.Dataset(self.name)

def loaded_datasets(self) -> t.List:
return fo.list_datasets()

def load(self):

samples = []

for i in tqdm(range(len(self.dataset))):
path, anno = self.dataset.get_imgpath_anno_pair(i)

sample = fo.Sample(filepath=path)

# Store classification in a field name of your choice
sample["ground_truth"] = fo.Classification(label=anno)

samples.append(sample)

# Create dataset

self.fo_dataset.add_samples(samples)
self.fo_dataset.persistent = self.persistent

def predict(self, model:nn.Module, transforms, n:int=inf, name="prediction", seed=51, explain:bool=False):
model.eval()
if n < inf:
self.view = self.fo_dataset.take(n, seed=seed)

with fo.ProgressBar() as pb:
for sample in pb(self.view):
image = Image.open(sample.filepath)
image = transforms(image).unsqueeze(0)

if self.cuda:
image = image.cuda()
pred = model(image)
index = torch.argmax(pred).item()
confidence = pred[:, index].item()


sample[name] = fo.Classification(
label=str(index),
confidence=confidence
)

if self.interpreter:
heatmap = self.interpreter(image_path=sample.filepath, image_tensor=image, transforms=transforms)
sample["heatmap"] = fo.Heatmap(map=heatmap)

sample.save()
print("Finished adding predictions")
```

0 comments on commit 69c4995

Please sign in to comment.